// saluki_io/compression.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use async_compression::{
8    tokio::write::{GzipEncoder, ZlibEncoder, ZstdEncoder},
9    Level,
10};
11use http::HeaderValue;
12use pin_project::pin_project;
13use tokio::io::AsyncWrite;
14use tracing::trace;
15
// "Red zone" threshold factor.
//
// See `CompressionEstimator::would_write_exceed_threshold` for details.
const THRESHOLD_RED_ZONE: f64 = 0.99;

// Pre-built `Content-Encoding` header values; `HeaderValue` clones of statics are cheap, so these
// are cloned on demand in `Compressor::content_encoding`.
static CONTENT_ENCODING_DEFLATE: HeaderValue = HeaderValue::from_static("deflate");
static CONTENT_ENCODING_GZIP: HeaderValue = HeaderValue::from_static("gzip");
static CONTENT_ENCODING_ZSTD: HeaderValue = HeaderValue::from_static("zstd");
24
/// Compression schemes supported by `Compressor`.
///
/// Each compressing variant carries the [`Level`] to configure the underlying encoder with.
#[derive(Copy, Clone, Debug)]
pub enum CompressionScheme {
    /// No compression.
    Noop,
    /// Gzip.
    Gzip(Level),
    /// Zlib.
    Zlib(Level),
    /// Zstd.
    Zstd(Level),
}
37
38impl CompressionScheme {
39    /// No compression.
40    pub const fn noop() -> Self {
41        Self::Noop
42    }
43
44    /// Gzip compression, using the default compression level (6).
45    pub const fn gzip_default() -> Self {
46        Self::Gzip(Level::Default)
47    }
48
49    /// Zlib compression, using the default compression level (6).
50    pub const fn zlib_default() -> Self {
51        Self::Zlib(Level::Default)
52    }
53
54    /// Zstd compression, using the default compression level (3).
55    pub const fn zstd_default() -> Self {
56        Self::Zstd(Level::Default)
57    }
58
59    /// Create a new compression scheme from a string and level.
60    ///
61    /// Level is only used if the scheme is `gzip` or `zstd`.
62    ///
63    /// Defaults to zstd with level 3.
64    pub fn new(scheme: &str, level: i32) -> Self {
65        match scheme {
66            "gzip" => Self::Gzip(Level::Precise(level)),
67            "zlib" => CompressionScheme::zlib_default(),
68            "zstd" => Self::Zstd(Level::Precise(level)),
69            _ => Self::Zstd(Level::Default),
70        }
71    }
72}
73
/// A writer wrapper that counts the bytes successfully written to the inner writer.
#[pin_project]
pub struct CountingWriter<W> {
    // Underlying writer that all operations are forwarded to.
    #[pin]
    inner: W,
    // Running total of bytes the inner writer reported as written.
    total_written: u64,
}
80
81impl<W> CountingWriter<W> {
82    fn new(inner: W) -> Self {
83        Self {
84            inner,
85            total_written: 0,
86        }
87    }
88
89    fn total_written(&self) -> u64 {
90        self.total_written
91    }
92
93    fn into_inner(self) -> W {
94        self.inner
95    }
96}
97
/// Statistics for a writer.
pub trait WriteStatistics {
    /// Returns the total number of bytes written.
    fn total_written(&self) -> u64;
}
103
104impl<W: AsyncWrite> AsyncWrite for CountingWriter<W> {
105    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
106        let mut this = self.project();
107        this.inner.as_mut().poll_write(cx, buf).map(|result| {
108            if let Ok(written) = &result {
109                *this.total_written += *written as u64;
110            }
111
112            result
113        })
114    }
115
116    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
117        self.project().inner.poll_flush(cx)
118    }
119
120    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
121        self.project().inner.poll_shutdown(cx)
122    }
123}
124
/// Generic compressor.
///
/// Exposes a semi-type-erased compression stream, by allowing the compression to be configured via `CompressionScheme`,
/// and generically wrapping over a given writer.
#[pin_project(project = CompressorProjected)]
pub enum Compressor<W: AsyncWrite> {
    /// No-op compressor.
    Noop(#[pin] CountingWriter<W>),
    /// Gzip compressor.
    Gzip(#[pin] GzipEncoder<CountingWriter<W>>),
    /// Zlib compressor.
    //
    // NOTE(review): unlike the other variants, Zlib wraps `W` directly rather than
    // `CountingWriter<W>`; `WriteStatistics::total_written` for this variant relies on the
    // encoder's own output counter instead. Confirm this asymmetry is intentional.
    Zlib(#[pin] ZlibEncoder<W>),
    /// Zstd compressor.
    Zstd(#[pin] ZstdEncoder<CountingWriter<W>>),
}
140
141impl<W: AsyncWrite> Compressor<W> {
142    /// Creates a new compressor from a given compression scheme and writer.
143    pub fn from_scheme(scheme: CompressionScheme, writer: W) -> Self {
144        match scheme {
145            CompressionScheme::Noop => Self::Noop(CountingWriter::new(writer)),
146            CompressionScheme::Gzip(level) => Self::Gzip(GzipEncoder::with_quality(CountingWriter::new(writer), level)),
147            CompressionScheme::Zlib(level) => Self::Zlib(ZlibEncoder::with_quality(writer, level)),
148            CompressionScheme::Zstd(level) => Self::Zstd(ZstdEncoder::with_quality(CountingWriter::new(writer), level)),
149        }
150    }
151
152    /// Consumes the compressor, returning the inner writer.
153    pub fn into_inner(self) -> W {
154        match self {
155            Self::Noop(encoder) => encoder.into_inner(),
156            Self::Gzip(encoder) => encoder.into_inner().into_inner(),
157            Self::Zlib(encoder) => encoder.into_inner(),
158            Self::Zstd(encoder) => encoder.into_inner().into_inner(),
159        }
160    }
161
162    /// Returns the content encoding for this compressor.
163    pub fn content_encoding(&self) -> Option<HeaderValue> {
164        match self {
165            Self::Noop(_) => None,
166            Self::Gzip(_) => Some(CONTENT_ENCODING_GZIP.clone()),
167            Self::Zlib(_) => Some(CONTENT_ENCODING_DEFLATE.clone()),
168            Self::Zstd(_) => Some(CONTENT_ENCODING_ZSTD.clone()),
169        }
170    }
171}
172
173impl<W: AsyncWrite> AsyncWrite for Compressor<W> {
174    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
175        match self.project() {
176            CompressorProjected::Noop(encoder) => encoder.poll_write(cx, buf),
177            CompressorProjected::Gzip(encoder) => encoder.poll_write(cx, buf),
178            CompressorProjected::Zlib(encoder) => encoder.poll_write(cx, buf),
179            CompressorProjected::Zstd(encoder) => encoder.poll_write(cx, buf),
180        }
181    }
182
183    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
184        match self.project() {
185            CompressorProjected::Noop(encoder) => encoder.poll_flush(cx),
186            CompressorProjected::Gzip(encoder) => encoder.poll_flush(cx),
187            CompressorProjected::Zlib(encoder) => encoder.poll_flush(cx),
188            CompressorProjected::Zstd(encoder) => encoder.poll_flush(cx),
189        }
190    }
191
192    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
193        match self.project() {
194            CompressorProjected::Noop(encoder) => encoder.poll_shutdown(cx),
195            CompressorProjected::Gzip(encoder) => encoder.poll_shutdown(cx),
196            CompressorProjected::Zlib(encoder) => encoder.poll_shutdown(cx),
197            CompressorProjected::Zstd(encoder) => encoder.poll_shutdown(cx),
198        }
199    }
200}
201
impl<W: AsyncWrite> WriteStatistics for Compressor<W> {
    /// Returns the total number of compressed bytes written to the underlying writer so far.
    fn total_written(&self) -> u64 {
        match self {
            Compressor::Noop(encoder) => encoder.total_written(),
            Compressor::Gzip(encoder) => encoder.get_ref().total_written(),
            // NOTE(review): the Zlib variant carries no `CountingWriter`, so this reads the
            // encoder's own output counter; presumably `total_out` reflects bytes produced by the
            // encoder rather than bytes accepted by the writer — confirm the two track closely
            // enough for `CompressionEstimator`'s purposes.
            Compressor::Zlib(encoder) => encoder.total_out(),
            Compressor::Zstd(encoder) => encoder.get_ref().total_written(),
        }
    }
}
212
/// A streaming estimator for the size of compressed data.
///
/// For many compression algorithms, there is a large amount of buffering and state during compression. This allows
/// compression algorithms to better compress data by finding patterns across the current and previous inputs, as well
/// as amortize how often they write compressed data to the output stream, increasing the potential efficiency of the
/// related function or system calls to do so.
///
/// However, this presents a problem when there is a need to ensure that the size of the compressed data does not exceed
/// a certain threshold. As many inputs can be written to the compressor before the next chunk of compressed data is
/// output, it is possible to write enough data that the compressed output exceeds the threshold. Further, many
/// compression algorithms/implementations do not provide a way to query the size of the compressed data without
/// expensive operations that either require doing multiple compression passes on different slices of the data, or early
/// flushing of compressed data, potentially leading to abnormally low compression ratios.
///
/// This estimator provides a way to estimate the size of the compressed data by combining both the known size of data
/// written to the compressor's output stream, as well as the inputs written to the compressor. We track the state
/// changes of the compressor, observing when it writes compressed data to the output stream. We additionally track
/// every write in terms of its uncompressed size. In combining the two, we estimate the worst-case size of the
/// compressed data based on what we know has been compressed so far and what we've written since the last time the
/// compressed flush to the output stream.
///
/// TODO: We should probably move this into `Compressor` itself, because it will also make it easier to do
/// per-compression-algorithm tweaks to the estimation logic if that's a path we want to take, and it also would be
/// cleaner and let us avoid any footguns around forgetting to update the necessary estimator state, etc.
#[derive(Debug, Default)]
pub struct CompressionEstimator {
    // Uncompressed bytes written since the compressor last flushed to the output stream.
    in_flight_uncompressed_len: usize,
    // All uncompressed bytes tracked so far.
    total_uncompressed_len: usize,
    // Compressed bytes observed as written to the output stream, per `WriteStatistics`.
    total_compressed_len: u64,
    // Overall compressed/uncompressed ratio, recalculated at every observed flush. Starts at 0.0
    // (via `Default`), which makes `estimated_len` ignore in-flight data until the first flush.
    current_compression_ratio: f64,
}
244
245impl CompressionEstimator {
246    /// Tracks a write to the compressor.
247    pub fn track_write<W>(&mut self, compressor: &W, uncompressed_len: usize)
248    where
249        W: WriteStatistics,
250    {
251        self.in_flight_uncompressed_len += uncompressed_len;
252        self.total_uncompressed_len += uncompressed_len;
253
254        let compressed_len = compressor.total_written();
255        let compressed_len_delta = (compressed_len - self.total_compressed_len) as usize;
256        if compressed_len_delta > 0 {
257            // We just observed the compressor flushing data, so we need to recalculate our compression ratio.
258            self.current_compression_ratio = compressed_len as f64 / self.total_uncompressed_len as f64;
259            self.total_compressed_len = compressed_len;
260            self.in_flight_uncompressed_len = 0;
261
262            trace!(
263                block_size = compressed_len_delta,
264                uncompressed_len = self.total_uncompressed_len,
265                compressed_len = self.total_compressed_len,
266                compression_ratio = self.current_compression_ratio,
267                "Compressor wrote block to output stream."
268            );
269        }
270    }
271
272    /// Resets the estimator.
273    pub fn reset(&mut self) {
274        self.in_flight_uncompressed_len = 0;
275        self.total_uncompressed_len = 0;
276        self.total_compressed_len = 0;
277        self.current_compression_ratio = 0.0;
278    }
279
280    /// Returns the estimated length of the compressor.
281    ///
282    /// This figure is the sum of the total bytes written by the compressor to the output stream and the number of
283    /// uncompressed bytes written to the compressor since the last time the compressor wrote to the output stream
284    /// when factoring in the estimated compression ratio over the overall output stream.
285    pub fn estimated_len(&self) -> usize {
286        let estimated_in_flight_compressed_len =
287            (self.in_flight_uncompressed_len as f64 * self.current_compression_ratio) as usize;
288
289        self.total_compressed_len as usize + estimated_in_flight_compressed_len
290    }
291
292    /// Estimates if writing `len` bytes to the compressor would cause the final compressed size to exceed `threshold`
293    /// bytes.
294    pub fn would_write_exceed_threshold(&self, len: usize, threshold: usize) -> bool {
295        // If the length of the data to be written exceeds the threshold, then it obviously would exceed the threshold.
296        if len > threshold {
297            return true;
298        }
299
300        // If we have yet to see any compressed data, we can't make a meaningful estimate, and this likely means that
301        // the compressor is still actively able to compress more data into the first block, which when eventually
302        // written, should never exceed the compressed size limit... so we choose to not block writes in this case.
303        if self.total_compressed_len == 0 {
304            return false;
305        }
306
307        // We adjust the given threshold down by a small amount to account for the fact that the final block written by
308        // the compressor has more variability in size than the rest, due to being more likely to be flushed before
309        // internal buffers are full and having the chance to most efficiently compress the data. Essentially, if we
310        // estimate that writing `len` more bytes would put our compressed length into the "red zone", then it's too
311        // risky to write those bytes.
312        //
313        // This is a bit of a fudge factor, but we arrived at the value through empirical testing with the regression
314        // detector benchmarks. Small enough to not have a major impact on payload size efficiency, but large enough to
315        // entirely get rid of compressed payload size limit violations.
316        let adjusted_threshold = (threshold as f64 * THRESHOLD_RED_ZONE) as usize;
317        self.estimated_len() + len > adjusted_threshold
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    // A fake compressor that lets tests dictate exactly when data is "flushed" and at what
    // compression ratio, so the estimator's behavior can be checked deterministically.
    struct MockCompressor {
        // Uncompressed bytes written since the last `flush`.
        // NOTE(review): this field is `u64` while `total_uncompressed_len` is `usize` — harmless,
        // but the types could be unified.
        current_uncompressed_len: u64,
        // All uncompressed bytes ever written.
        total_uncompressed_len: usize,
        // Total "compressed" output produced by flushes so far.
        compressed_len: u64,
    }

    impl MockCompressor {
        fn new() -> Self {
            MockCompressor {
                current_uncompressed_len: 0,
                total_uncompressed_len: 0,
                compressed_len: 0,
            }
        }

        // Simulates writing `n` uncompressed bytes into the compressor (no output yet).
        fn write(&mut self, n: usize) {
            self.current_uncompressed_len += n as u64;
            self.total_uncompressed_len += n;
        }

        // Simulates the compressor flushing everything buffered since the last flush, compressed
        // at the given ratio (ratios > 1.0 model pathological expansion).
        fn flush(&mut self, compression_ratio: f64) {
            self.compressed_len += (self.current_uncompressed_len as f64 * compression_ratio) as u64;
            self.current_uncompressed_len = 0;
        }

        fn total_uncompressed_len(&self) -> usize {
            self.total_uncompressed_len
        }
    }

    impl WriteStatistics for MockCompressor {
        fn total_written(&self) -> u64 {
            self.compressed_len
        }
    }

    #[test]
    fn compression_estimator_no_output() {
        let estimator = CompressionEstimator::default();

        // With no compressed output observed yet, only writes bigger than the threshold itself
        // are rejected.
        assert!(!estimator.would_write_exceed_threshold(10, 100));
        assert!(estimator.would_write_exceed_threshold(100, 90));
    }

    #[test]
    fn compression_estimator_single_flush() {
        const MAX_COMPRESSED_LEN: usize = 100;
        const COMPRESSION_RATIO: f64 = 0.7;
        const WRITE_LEN: usize = 50;

        let mut estimator = CompressionEstimator::default();

        // Create our mock compressor and do a basic write, and then flush, so that our estimator can get some data.
        let mut compressor = MockCompressor::new();
        assert!(!estimator.would_write_exceed_threshold(WRITE_LEN, MAX_COMPRESSED_LEN));

        // Write 50 bytes with a compression ratio of 0.7, giving us 35 bytes compressed.
        compressor.write(WRITE_LEN);
        compressor.flush(COMPRESSION_RATIO);
        assert_eq!(compressor.total_written(), 35);

        estimator.track_write(&compressor, WRITE_LEN);

        // We should be able to write 65 more bytes compressed, so 100 bytes uncompressed, given the compression ratio we have (0.7),
        // would give us 70 bytes estimated.. which is over the threshold.
        //
        // (Remember: `would_write_exceed_threshold` counts the candidate write uncompressed, as the worst case.)
        assert!(estimator.would_write_exceed_threshold(100, MAX_COMPRESSED_LEN));

        // However, another 50 byte write would theoretically just be another 35 bytes compressed, so 85 bytes compressed total,
        // which is under our threshold and should be allowed.
        assert!(!estimator.would_write_exceed_threshold(WRITE_LEN, MAX_COMPRESSED_LEN));
    }

    #[test]
    fn compression_estimator_multiple_flush_partial() {
        const MAX_COMPRESSED_LEN: usize = 5000;
        const FIRST_COMPRESSION_RATIO: f64 = 0.7;
        const FIRST_WRITE_LEN: usize = 5000;
        const SECOND_COMPRESSION_RATIO: f64 = 2.1;
        const SECOND_WRITE_LEN: usize = 300;
        const THIRD_WRITE_LEN: usize = 820;

        let mut estimator = CompressionEstimator::default();

        // Create our mock compressor and assert we can do our first write.
        let mut compressor = MockCompressor::new();
        assert!(!estimator.would_write_exceed_threshold(FIRST_WRITE_LEN, MAX_COMPRESSED_LEN));

        // Write 5,000 bytes with a compression ratio of 0.7, giving us 3,500 bytes compressed.
        compressor.write(FIRST_WRITE_LEN);
        compressor.flush(FIRST_COMPRESSION_RATIO);
        assert_eq!(compressor.total_uncompressed_len(), FIRST_WRITE_LEN);
        assert_eq!(compressor.total_written(), 3500);

        estimator.track_write(&compressor, FIRST_WRITE_LEN);

        // We now do second write that simulates a "short" flush on the compressor: this might just be the compressor writing a partial block.
        //
        // What we want to test here is the estimator's ability to focus on the overall compression ratio rather than getting "lost" due to
        // a single block being flushed which, when viewed naively, appears to be vastly bigger than the actual in-flight uncompressed data
        // that it represents.
        //
        // We end up writing 300 bytes uncompressed with a compression ratio of 2.1, giving us 630 bytes compressed. We now have a total of
        // 5,300 bytes uncompressed and 4,130 bytes compressed. Our overall compression ratio is now ~0.78.
        compressor.write(SECOND_WRITE_LEN);
        compressor.flush(SECOND_COMPRESSION_RATIO);
        assert_eq!(compressor.total_uncompressed_len(), FIRST_WRITE_LEN + SECOND_WRITE_LEN);
        assert_eq!(compressor.total_written(), 4130);

        estimator.track_write(&compressor, SECOND_WRITE_LEN);

        // At this point, with our compressed limit of 5,000 bytes, we should be able to fit in another 870 bytes compressed. We do have to
        // compensate for the "red zone" threshold, though, which puts the adjusted threshold at 4,950 bytes and thus allows at most 820
        // more bytes (4,130 + 820 == 4,950, which does not exceed the adjusted threshold).
        //
        // We pass the would-be compressed length to `would_write_exceed_threshold` because it counts the candidate write uncompressed —
        // i.e. the worst-case scenario where the write would not be compressed at all.
        assert!(!estimator.would_write_exceed_threshold(THIRD_WRITE_LEN, MAX_COMPRESSED_LEN));
    }
}
442}