1use std::{
2 io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use async_compression::{
8 tokio::write::{GzipEncoder, ZlibEncoder, ZstdEncoder},
9 Level,
10};
11use http::HeaderValue;
12use pin_project::pin_project;
13use tokio::io::AsyncWrite;
14use tracing::trace;
15
/// Fraction of a caller-supplied size threshold that
/// `CompressionEstimator::would_write_exceed_threshold` actually allows.
///
/// The threshold is multiplied by this value before comparison, so the last 1% of the threshold
/// is never used. NOTE(review): presumably this margin absorbs error in the compressed-size
/// estimate — confirm.
const THRESHOLD_RED_ZONE: f64 = 0.99;
20
// Pre-built header values returned by `Compressor::content_encoding`; they are cloned on each call
// instead of being re-parsed from the string literal.

/// Header value reported for the `Zlib` compressor.
static CONTENT_ENCODING_DEFLATE: HeaderValue = HeaderValue::from_static("deflate");
/// Header value reported for the `Gzip` compressor.
static CONTENT_ENCODING_GZIP: HeaderValue = HeaderValue::from_static("gzip");
/// Header value reported for the `Zstd` compressor.
static CONTENT_ENCODING_ZSTD: HeaderValue = HeaderValue::from_static("zstd");
24
/// Compression algorithm, and its level where applicable, used to build a `Compressor`.
#[derive(Copy, Clone, Debug)]
pub enum CompressionScheme {
    /// No compression; data is passed through unchanged.
    Noop,
    /// Gzip compression at the given level.
    Gzip(Level),
    /// Zlib compression at the given level.
    Zlib(Level),
    /// Zstandard compression at the given level.
    Zstd(Level),
}
37
38impl CompressionScheme {
39 pub const fn noop() -> Self {
41 Self::Noop
42 }
43
44 pub const fn gzip_default() -> Self {
46 Self::Gzip(Level::Default)
47 }
48
49 pub const fn zlib_default() -> Self {
51 Self::Zlib(Level::Default)
52 }
53
54 pub const fn zstd_default() -> Self {
56 Self::Zstd(Level::Default)
57 }
58
59 pub fn new(scheme: &str, level: i32) -> Self {
65 match scheme {
66 "gzip" => Self::Gzip(Level::Precise(level)),
67 "zlib" => CompressionScheme::zlib_default(),
68 "zstd" => Self::Zstd(Level::Precise(level)),
69 _ => Self::Zstd(Level::Default),
70 }
71 }
72}
73
/// Writer wrapper that forwards to `inner` and keeps a running total of the bytes `inner`
/// reported as written.
#[pin_project]
pub struct CountingWriter<W> {
    /// The wrapped writer; pinned so its `AsyncWrite` methods can be polled through the wrapper.
    #[pin]
    inner: W,
    /// Sum of the byte counts returned by successful writes to `inner`.
    total_written: u64,
}
80
81impl<W> CountingWriter<W> {
82 fn new(inner: W) -> Self {
83 Self {
84 inner,
85 total_written: 0,
86 }
87 }
88
89 fn total_written(&self) -> u64 {
90 self.total_written
91 }
92
93 fn into_inner(self) -> W {
94 self.inner
95 }
96}
97
/// Exposes how much output a writer has produced.
///
/// `CompressionEstimator::track_write` uses this to observe a compressor's output length.
pub trait WriteStatistics {
    /// Returns the total number of bytes written so far.
    fn total_written(&self) -> u64;
}
103
104impl<W: AsyncWrite> AsyncWrite for CountingWriter<W> {
105 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
106 let mut this = self.project();
107 this.inner.as_mut().poll_write(cx, buf).map(|result| {
108 if let Ok(written) = &result {
109 *this.total_written += *written as u64;
110 }
111
112 result
113 })
114 }
115
116 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
117 self.project().inner.poll_flush(cx)
118 }
119
120 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
121 self.project().inner.poll_shutdown(cx)
122 }
123}
124
/// Asynchronous writer that applies one of the supported compression schemes to everything
/// written to it before handing the result to the underlying writer `W`.
#[pin_project(project = CompressorProjected)]
pub enum Compressor<W: AsyncWrite> {
    /// Pass-through: writes go straight to the counted writer with no encoder in between.
    Noop(#[pin] CountingWriter<W>),
    /// Gzip encoder writing into a counted writer.
    Gzip(#[pin] GzipEncoder<CountingWriter<W>>),
    /// Zlib encoder writing directly into `W`; unlike the other variants there is no
    /// `CountingWriter` wrapper here.
    Zlib(#[pin] ZlibEncoder<W>),
    /// Zstandard encoder writing into a counted writer.
    Zstd(#[pin] ZstdEncoder<CountingWriter<W>>),
}
140
141impl<W: AsyncWrite> Compressor<W> {
142 pub fn from_scheme(scheme: CompressionScheme, writer: W) -> Self {
144 match scheme {
145 CompressionScheme::Noop => Self::Noop(CountingWriter::new(writer)),
146 CompressionScheme::Gzip(level) => Self::Gzip(GzipEncoder::with_quality(CountingWriter::new(writer), level)),
147 CompressionScheme::Zlib(level) => Self::Zlib(ZlibEncoder::with_quality(writer, level)),
148 CompressionScheme::Zstd(level) => Self::Zstd(ZstdEncoder::with_quality(CountingWriter::new(writer), level)),
149 }
150 }
151
152 pub fn into_inner(self) -> W {
154 match self {
155 Self::Noop(encoder) => encoder.into_inner(),
156 Self::Gzip(encoder) => encoder.into_inner().into_inner(),
157 Self::Zlib(encoder) => encoder.into_inner(),
158 Self::Zstd(encoder) => encoder.into_inner().into_inner(),
159 }
160 }
161
162 pub fn content_encoding(&self) -> Option<HeaderValue> {
164 match self {
165 Self::Noop(_) => None,
166 Self::Gzip(_) => Some(CONTENT_ENCODING_GZIP.clone()),
167 Self::Zlib(_) => Some(CONTENT_ENCODING_DEFLATE.clone()),
168 Self::Zstd(_) => Some(CONTENT_ENCODING_ZSTD.clone()),
169 }
170 }
171}
172
173impl<W: AsyncWrite> AsyncWrite for Compressor<W> {
174 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
175 match self.project() {
176 CompressorProjected::Noop(encoder) => encoder.poll_write(cx, buf),
177 CompressorProjected::Gzip(encoder) => encoder.poll_write(cx, buf),
178 CompressorProjected::Zlib(encoder) => encoder.poll_write(cx, buf),
179 CompressorProjected::Zstd(encoder) => encoder.poll_write(cx, buf),
180 }
181 }
182
183 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
184 match self.project() {
185 CompressorProjected::Noop(encoder) => encoder.poll_flush(cx),
186 CompressorProjected::Gzip(encoder) => encoder.poll_flush(cx),
187 CompressorProjected::Zlib(encoder) => encoder.poll_flush(cx),
188 CompressorProjected::Zstd(encoder) => encoder.poll_flush(cx),
189 }
190 }
191
192 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
193 match self.project() {
194 CompressorProjected::Noop(encoder) => encoder.poll_shutdown(cx),
195 CompressorProjected::Gzip(encoder) => encoder.poll_shutdown(cx),
196 CompressorProjected::Zlib(encoder) => encoder.poll_shutdown(cx),
197 CompressorProjected::Zstd(encoder) => encoder.poll_shutdown(cx),
198 }
199 }
200}
201
202impl<W: AsyncWrite> WriteStatistics for Compressor<W> {
203 fn total_written(&self) -> u64 {
204 match self {
205 Compressor::Noop(encoder) => encoder.total_written(),
206 Compressor::Gzip(encoder) => encoder.get_ref().total_written(),
207 Compressor::Zlib(encoder) => encoder.total_out(),
208 Compressor::Zstd(encoder) => encoder.get_ref().total_written(),
209 }
210 }
211}
212
/// Estimates the compressed size of data that has been handed to a compressor, including data
/// that has not yet shown up in the compressor's output.
#[derive(Debug, Default)]
pub struct CompressionEstimator {
    /// Uncompressed bytes tracked since the compressor's reported output length last grew.
    in_flight_uncompressed_len: usize,
    /// All uncompressed bytes tracked since creation or the last `reset`.
    total_uncompressed_len: usize,
    /// The compressor's reported output length as of the last time it was seen to grow.
    total_compressed_len: u64,
    /// Compressed-to-uncompressed ratio computed the last time the output length grew.
    current_compression_ratio: f64,
}
244
245impl CompressionEstimator {
246 pub fn track_write<W>(&mut self, compressor: &W, uncompressed_len: usize)
248 where
249 W: WriteStatistics,
250 {
251 self.in_flight_uncompressed_len += uncompressed_len;
252 self.total_uncompressed_len += uncompressed_len;
253
254 let compressed_len = compressor.total_written();
255 let compressed_len_delta = (compressed_len - self.total_compressed_len) as usize;
256 if compressed_len_delta > 0 {
257 self.current_compression_ratio = compressed_len as f64 / self.total_uncompressed_len as f64;
259 self.total_compressed_len = compressed_len;
260 self.in_flight_uncompressed_len = 0;
261
262 trace!(
263 block_size = compressed_len_delta,
264 uncompressed_len = self.total_uncompressed_len,
265 compressed_len = self.total_compressed_len,
266 compression_ratio = self.current_compression_ratio,
267 "Compressor wrote block to output stream."
268 );
269 }
270 }
271
272 pub fn reset(&mut self) {
274 self.in_flight_uncompressed_len = 0;
275 self.total_uncompressed_len = 0;
276 self.total_compressed_len = 0;
277 self.current_compression_ratio = 0.0;
278 }
279
280 pub fn estimated_len(&self) -> usize {
286 let estimated_in_flight_compressed_len =
287 (self.in_flight_uncompressed_len as f64 * self.current_compression_ratio) as usize;
288
289 self.total_compressed_len as usize + estimated_in_flight_compressed_len
290 }
291
292 pub fn would_write_exceed_threshold(&self, len: usize, threshold: usize) -> bool {
295 if self.total_compressed_len == 0 {
299 return false;
300 }
301
302 let adjusted_threshold = (threshold as f64 * THRESHOLD_RED_ZONE) as usize;
312 self.estimated_len() + len > adjusted_threshold
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that mimics a block compressor: writes accumulate uncompressed bytes, and a
    /// flush converts the accumulated bytes into compressed output at a caller-chosen ratio.
    struct MockCompressor {
        // Uncompressed bytes written since the last flush.
        current_uncompressed_len: u64,
        // Uncompressed bytes written over the lifetime of the mock.
        total_uncompressed_len: usize,
        // Compressed bytes "emitted" so far; this is what `total_written` reports.
        compressed_len: u64,
    }

    impl MockCompressor {
        /// Creates a mock with all counters at zero.
        fn new() -> Self {
            MockCompressor {
                current_uncompressed_len: 0,
                total_uncompressed_len: 0,
                compressed_len: 0,
            }
        }

        /// Buffers `n` uncompressed bytes without producing any output.
        fn write(&mut self, n: usize) {
            self.current_uncompressed_len += n as u64;
            self.total_uncompressed_len += n;
        }

        /// Emits the buffered bytes as compressed output, scaled by `compression_ratio`, and
        /// clears the buffer.
        fn flush(&mut self, compression_ratio: f64) {
            self.compressed_len += (self.current_uncompressed_len as f64 * compression_ratio) as u64;
            self.current_uncompressed_len = 0;
        }

        /// Returns the total number of uncompressed bytes written to the mock.
        fn total_uncompressed_len(&self) -> usize {
            self.total_uncompressed_len
        }
    }

    impl WriteStatistics for MockCompressor {
        fn total_written(&self) -> u64 {
            self.compressed_len
        }
    }

    #[test]
    fn compression_estimator_no_output() {
        let estimator = CompressionEstimator::default();

        // With no compressed output observed, the estimator never reports an exceeded threshold,
        // even when the write alone is larger than the threshold.
        assert!(!estimator.would_write_exceed_threshold(10, 100));
        assert!(!estimator.would_write_exceed_threshold(100, 90));
    }

    #[test]
    fn compression_estimator_single_flush() {
        const MAX_COMPRESSED_LEN: usize = 100;
        const COMPRESSION_RATIO: f64 = 0.7;
        const WRITE_LEN: usize = 50;

        let mut estimator = CompressionEstimator::default();

        let mut compressor = MockCompressor::new();
        // Nothing has been tracked yet, so the write is always allowed.
        assert!(!estimator.would_write_exceed_threshold(WRITE_LEN, MAX_COMPRESSED_LEN));

        compressor.write(WRITE_LEN);
        // 50 bytes at a 0.7 ratio become 35 compressed bytes.
        compressor.flush(COMPRESSION_RATIO);
        assert_eq!(compressor.total_written(), 35);

        estimator.track_write(&compressor, WRITE_LEN);

        // 35 + 100 = 135, which is over the adjusted threshold of 99.
        assert!(estimator.would_write_exceed_threshold(100, MAX_COMPRESSED_LEN));

        // 35 + 50 = 85, which is under the adjusted threshold of 99.
        assert!(!estimator.would_write_exceed_threshold(WRITE_LEN, MAX_COMPRESSED_LEN));
    }

    #[test]
    fn compression_estimator_multiple_flush_partial() {
        const MAX_COMPRESSED_LEN: usize = 5000;
        const FIRST_COMPRESSION_RATIO: f64 = 0.7;
        const FIRST_WRITE_LEN: usize = 5000;
        const SECOND_COMPRESSION_RATIO: f64 = 2.1;
        const SECOND_WRITE_LEN: usize = 300;
        const THIRD_WRITE_LEN: usize = 820;

        let mut estimator = CompressionEstimator::default();

        let mut compressor = MockCompressor::new();
        // Nothing has been tracked yet, so the write is always allowed.
        assert!(!estimator.would_write_exceed_threshold(FIRST_WRITE_LEN, MAX_COMPRESSED_LEN));

        compressor.write(FIRST_WRITE_LEN);
        // 5000 bytes at a 0.7 ratio become 3500 compressed bytes.
        compressor.flush(FIRST_COMPRESSION_RATIO);
        assert_eq!(compressor.total_uncompressed_len(), FIRST_WRITE_LEN);
        assert_eq!(compressor.total_written(), 3500);

        estimator.track_write(&compressor, FIRST_WRITE_LEN);

        compressor.write(SECOND_WRITE_LEN);
        // The second block expands rather than shrinks: 300 bytes at a 2.1 ratio add 630
        // compressed bytes, for 4130 in total.
        compressor.flush(SECOND_COMPRESSION_RATIO);
        assert_eq!(compressor.total_uncompressed_len(), FIRST_WRITE_LEN + SECOND_WRITE_LEN);
        assert_eq!(compressor.total_written(), 4130);

        estimator.track_write(&compressor, SECOND_WRITE_LEN);

        // 4130 + 820 = 4950, which equals the adjusted threshold (5000 * 0.99) and therefore
        // does not exceed it.
        assert!(!estimator.would_write_exceed_threshold(THIRD_WRITE_LEN, MAX_COMPRESSED_LEN));
    }
}