1use std::{
2 io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use async_compression::{
8 tokio::write::{GzipEncoder, ZlibEncoder, ZstdEncoder},
9 Level,
10};
11use http::HeaderValue;
12use pin_project::pin_project;
13use tokio::io::AsyncWrite;
14use tracing::trace;
15
/// Fraction of the caller-supplied threshold that [`CompressionEstimator::would_write_exceed_threshold`] actually
/// compares its estimate against (the threshold is scaled by this factor before comparing).
///
/// NOTE(review): presumably chosen to leave a little headroom for error in the size estimate — confirm.
const THRESHOLD_RED_ZONE: f64 = 0.99;

/// `Content-Encoding` header value reported for [`Compressor::Zlib`] output.
static CONTENT_ENCODING_DEFLATE: HeaderValue = HeaderValue::from_static("deflate");
/// `Content-Encoding` header value reported for [`Compressor::Gzip`] output.
static CONTENT_ENCODING_GZIP: HeaderValue = HeaderValue::from_static("gzip");
/// `Content-Encoding` header value reported for [`Compressor::Zstd`] output.
static CONTENT_ENCODING_ZSTD: HeaderValue = HeaderValue::from_static("zstd");
24
/// Selects the compression algorithm (and, for real codecs, the compression level) used when building a
/// [`Compressor`] via [`Compressor::from_scheme`].
#[derive(Copy, Clone, Debug)]
pub enum CompressionScheme {
    /// No compression: bytes are passed through to the writer unchanged.
    Noop,
    /// Gzip compression at the given level.
    Gzip(Level),
    /// Zlib compression at the given level; its output is labelled `Content-Encoding: deflate`.
    Zlib(Level),
    /// Zstandard compression at the given level.
    Zstd(Level),
}
37
38impl CompressionScheme {
39 pub const fn noop() -> Self {
41 Self::Noop
42 }
43
44 pub const fn gzip_default() -> Self {
46 Self::Gzip(Level::Default)
47 }
48
49 pub const fn zlib_default() -> Self {
51 Self::Zlib(Level::Default)
52 }
53
54 pub const fn zstd_default() -> Self {
56 Self::Zstd(Level::Default)
57 }
58
59 pub fn new(scheme: &str, level: i32) -> Self {
65 match scheme {
66 "gzip" => Self::Gzip(Level::Precise(level)),
67 "zlib" => CompressionScheme::zlib_default(),
68 "zstd" => Self::Zstd(Level::Precise(level)),
69 _ => Self::Zstd(Level::Default),
70 }
71 }
72}
73
/// [`AsyncWrite`] adapter that forwards every call to an inner writer while keeping a running count of the bytes the
/// inner writer has accepted.
#[pin_project]
pub struct CountingWriter<W> {
    /// The wrapped writer; structurally pinned so it can be polled through the projection.
    #[pin]
    inner: W,
    /// Sum of the byte counts returned by successful `poll_write` calls on `inner`.
    total_written: u64,
}
80
81impl<W> CountingWriter<W> {
82 fn new(inner: W) -> Self {
83 Self {
84 inner,
85 total_written: 0,
86 }
87 }
88
89 fn total_written(&self) -> u64 {
90 self.total_written
91 }
92
93 fn into_inner(self) -> W {
94 self.inner
95 }
96}
97
/// Exposes a running count of the bytes a writer has produced as output.
pub trait WriteStatistics {
    /// Returns the total number of output bytes produced so far.
    fn total_written(&self) -> u64;
}
103
104impl<W: AsyncWrite> AsyncWrite for CountingWriter<W> {
105 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
106 let mut this = self.project();
107 this.inner.as_mut().poll_write(cx, buf).map(|result| {
108 if let Ok(written) = &result {
109 *this.total_written += *written as u64;
110 }
111
112 result
113 })
114 }
115
116 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
117 self.project().inner.poll_flush(cx)
118 }
119
120 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
121 self.project().inner.poll_shutdown(cx)
122 }
123}
124
/// Streaming compressor that writes its (possibly compressed) output into `W`, with the algorithm chosen by a
/// [`CompressionScheme`].
///
/// The Noop, Gzip and Zstd variants route their output through a [`CountingWriter`] so the bytes reaching `W` can be
/// counted. The Zlib variant writes to `W` directly; its output size is instead read from the encoder's own
/// `total_out()` (see the `WriteStatistics` impl below).
#[pin_project(project = CompressorProjected)]
pub enum Compressor<W: AsyncWrite> {
    /// Pass-through writer; no compression.
    Noop(#[pin] CountingWriter<W>),
    /// Gzip encoder writing into a byte-counting wrapper around `W`.
    Gzip(#[pin] GzipEncoder<CountingWriter<W>>),
    /// Zlib encoder writing directly into `W`.
    Zlib(#[pin] ZlibEncoder<W>),
    /// Zstd encoder writing into a byte-counting wrapper around `W`.
    Zstd(#[pin] ZstdEncoder<CountingWriter<W>>),
}
140
141impl<W: AsyncWrite> Compressor<W> {
142 pub fn from_scheme(scheme: CompressionScheme, writer: W) -> Self {
144 match scheme {
145 CompressionScheme::Noop => Self::Noop(CountingWriter::new(writer)),
146 CompressionScheme::Gzip(level) => Self::Gzip(GzipEncoder::with_quality(CountingWriter::new(writer), level)),
147 CompressionScheme::Zlib(level) => Self::Zlib(ZlibEncoder::with_quality(writer, level)),
148 CompressionScheme::Zstd(level) => Self::Zstd(ZstdEncoder::with_quality(CountingWriter::new(writer), level)),
149 }
150 }
151
152 pub fn into_inner(self) -> W {
154 match self {
155 Self::Noop(encoder) => encoder.into_inner(),
156 Self::Gzip(encoder) => encoder.into_inner().into_inner(),
157 Self::Zlib(encoder) => encoder.into_inner(),
158 Self::Zstd(encoder) => encoder.into_inner().into_inner(),
159 }
160 }
161
162 pub fn content_encoding(&self) -> Option<HeaderValue> {
164 match self {
165 Self::Noop(_) => None,
166 Self::Gzip(_) => Some(CONTENT_ENCODING_GZIP.clone()),
167 Self::Zlib(_) => Some(CONTENT_ENCODING_DEFLATE.clone()),
168 Self::Zstd(_) => Some(CONTENT_ENCODING_ZSTD.clone()),
169 }
170 }
171}
172
173impl<W: AsyncWrite> AsyncWrite for Compressor<W> {
174 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
175 match self.project() {
176 CompressorProjected::Noop(encoder) => encoder.poll_write(cx, buf),
177 CompressorProjected::Gzip(encoder) => encoder.poll_write(cx, buf),
178 CompressorProjected::Zlib(encoder) => encoder.poll_write(cx, buf),
179 CompressorProjected::Zstd(encoder) => encoder.poll_write(cx, buf),
180 }
181 }
182
183 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
184 match self.project() {
185 CompressorProjected::Noop(encoder) => encoder.poll_flush(cx),
186 CompressorProjected::Gzip(encoder) => encoder.poll_flush(cx),
187 CompressorProjected::Zlib(encoder) => encoder.poll_flush(cx),
188 CompressorProjected::Zstd(encoder) => encoder.poll_flush(cx),
189 }
190 }
191
192 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
193 match self.project() {
194 CompressorProjected::Noop(encoder) => encoder.poll_shutdown(cx),
195 CompressorProjected::Gzip(encoder) => encoder.poll_shutdown(cx),
196 CompressorProjected::Zlib(encoder) => encoder.poll_shutdown(cx),
197 CompressorProjected::Zstd(encoder) => encoder.poll_shutdown(cx),
198 }
199 }
200}
201
202impl<W: AsyncWrite> WriteStatistics for Compressor<W> {
203 fn total_written(&self) -> u64 {
204 match self {
205 Compressor::Noop(encoder) => encoder.total_written(),
206 Compressor::Gzip(encoder) => encoder.get_ref().total_written(),
207 Compressor::Zlib(encoder) => encoder.total_out(),
208 Compressor::Zstd(encoder) => encoder.get_ref().total_written(),
209 }
210 }
211}
212
/// Estimates the compressed size of everything written through a compressor.
///
/// A compressor's output can lag behind its input, so this tracks the uncompressed bytes written since the
/// compressor's output length last changed ("in flight"). It extrapolates their compressed size from the compression
/// ratio observed so far.
#[derive(Debug, Default)]
pub struct CompressionEstimator {
    /// Uncompressed bytes tracked since the compressor's output length last changed.
    in_flight_uncompressed_len: usize,
    /// Uncompressed bytes tracked since creation or the last `reset`.
    total_uncompressed_len: usize,
    /// Compressor output length as of the last observed change.
    total_compressed_len: u64,
    /// Output/input ratio computed at the last observed change; 0.0 until output has been observed.
    current_compression_ratio: f64,
}
244
245impl CompressionEstimator {
246 pub fn track_write<W>(&mut self, compressor: &W, uncompressed_len: usize)
248 where
249 W: WriteStatistics,
250 {
251 self.in_flight_uncompressed_len += uncompressed_len;
252 self.total_uncompressed_len += uncompressed_len;
253
254 let compressed_len = compressor.total_written();
255 let compressed_len_delta = (compressed_len - self.total_compressed_len) as usize;
256 if compressed_len_delta > 0 {
257 self.current_compression_ratio = compressed_len as f64 / self.total_uncompressed_len as f64;
259 self.total_compressed_len = compressed_len;
260 self.in_flight_uncompressed_len = 0;
261
262 trace!(
263 block_size = compressed_len_delta,
264 uncompressed_len = self.total_uncompressed_len,
265 compressed_len = self.total_compressed_len,
266 compression_ratio = self.current_compression_ratio,
267 "Compressor wrote block to output stream."
268 );
269 }
270 }
271
272 pub fn reset(&mut self) {
274 self.in_flight_uncompressed_len = 0;
275 self.total_uncompressed_len = 0;
276 self.total_compressed_len = 0;
277 self.current_compression_ratio = 0.0;
278 }
279
280 pub fn estimated_len(&self) -> usize {
286 let estimated_in_flight_compressed_len =
287 (self.in_flight_uncompressed_len as f64 * self.current_compression_ratio) as usize;
288
289 self.total_compressed_len as usize + estimated_in_flight_compressed_len
290 }
291
292 pub fn would_write_exceed_threshold(&self, len: usize, threshold: usize) -> bool {
295 if len > threshold {
297 return true;
298 }
299
300 if self.total_compressed_len == 0 {
304 return false;
305 }
306
307 let adjusted_threshold = (threshold as f64 * THRESHOLD_RED_ZONE) as usize;
317 self.estimated_len() + len > adjusted_threshold
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double for a compressor: `write` buffers bytes, and `flush` turns the buffered bytes into output at a
    /// caller-chosen compression ratio.
    struct MockCompressor {
        /// Bytes written since the last `flush`.
        current_uncompressed_len: u64,
        /// Bytes written over the mock's lifetime.
        total_uncompressed_len: usize,
        /// Simulated output length accumulated by `flush` calls.
        compressed_len: u64,
    }

    impl MockCompressor {
        fn new() -> Self {
            MockCompressor {
                current_uncompressed_len: 0,
                total_uncompressed_len: 0,
                compressed_len: 0,
            }
        }

        /// Buffers `n` uncompressed bytes without producing output.
        fn write(&mut self, n: usize) {
            self.current_uncompressed_len += n as u64;
            self.total_uncompressed_len += n;
        }

        /// Emits the buffered bytes as output scaled by `compression_ratio` (truncated toward zero).
        fn flush(&mut self, compression_ratio: f64) {
            self.compressed_len += (self.current_uncompressed_len as f64 * compression_ratio) as u64;
            self.current_uncompressed_len = 0;
        }

        fn total_uncompressed_len(&self) -> usize {
            self.total_uncompressed_len
        }
    }

    impl WriteStatistics for MockCompressor {
        fn total_written(&self) -> u64 {
            self.compressed_len
        }
    }

    #[test]
    fn compression_estimator_no_output() {
        let estimator = CompressionEstimator::default();

        // With no output observed yet, only a write larger than the threshold itself is reported as exceeding it.
        assert!(!estimator.would_write_exceed_threshold(10, 100));
        assert!(estimator.would_write_exceed_threshold(100, 90));
    }

    #[test]
    fn compression_estimator_single_flush() {
        const MAX_COMPRESSED_LEN: usize = 100;
        const COMPRESSION_RATIO: f64 = 0.7;
        const WRITE_LEN: usize = 50;

        let mut estimator = CompressionEstimator::default();

        let mut compressor = MockCompressor::new();
        // Nothing written yet, so the first write is allowed.
        assert!(!estimator.would_write_exceed_threshold(WRITE_LEN, MAX_COMPRESSED_LEN));

        compressor.write(WRITE_LEN);
        // Flushing 50 bytes at a 0.7 ratio yields 35 output bytes.
        compressor.flush(COMPRESSION_RATIO);
        assert_eq!(compressor.total_written(), 35);

        estimator.track_write(&compressor, WRITE_LEN);

        // Estimate: 35 observed + 100 > 99 (the threshold scaled by THRESHOLD_RED_ZONE).
        assert!(estimator.would_write_exceed_threshold(100, MAX_COMPRESSED_LEN));

        // Estimate: 35 observed + 50 = 85, which stays under 99.
        assert!(!estimator.would_write_exceed_threshold(WRITE_LEN, MAX_COMPRESSED_LEN));
    }

    #[test]
    fn compression_estimator_multiple_flush_partial() {
        const MAX_COMPRESSED_LEN: usize = 5000;
        const FIRST_COMPRESSION_RATIO: f64 = 0.7;
        const FIRST_WRITE_LEN: usize = 5000;
        const SECOND_COMPRESSION_RATIO: f64 = 2.1;
        const SECOND_WRITE_LEN: usize = 300;
        const THIRD_WRITE_LEN: usize = 820;

        let mut estimator = CompressionEstimator::default();

        let mut compressor = MockCompressor::new();
        // A write exactly equal to the threshold is not rejected up front.
        assert!(!estimator.would_write_exceed_threshold(FIRST_WRITE_LEN, MAX_COMPRESSED_LEN));

        compressor.write(FIRST_WRITE_LEN);
        // 5000 bytes at a 0.7 ratio yields 3500 output bytes.
        compressor.flush(FIRST_COMPRESSION_RATIO);
        assert_eq!(compressor.total_uncompressed_len(), FIRST_WRITE_LEN);
        assert_eq!(compressor.total_written(), 3500);

        estimator.track_write(&compressor, FIRST_WRITE_LEN);

        compressor.write(SECOND_WRITE_LEN);
        // A second block that expands (ratio > 1): 300 bytes become 630, for 4130 output bytes in total.
        compressor.flush(SECOND_COMPRESSION_RATIO);
        assert_eq!(compressor.total_uncompressed_len(), FIRST_WRITE_LEN + SECOND_WRITE_LEN);
        assert_eq!(compressor.total_written(), 4130);

        estimator.track_write(&compressor, SECOND_WRITE_LEN);

        // Boundary case: 4130 observed + 820 = 4950, exactly the adjusted threshold (5000 * 0.99). The comparison is
        // strict, so this write is not reported as exceeding it.
        assert!(!estimator.would_write_exceed_threshold(THIRD_WRITE_LEN, MAX_COMPRESSED_LEN));
    }
}