1use std::{
2 io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use async_compression::{
8 tokio::write::{ZlibEncoder, ZstdEncoder},
9 Level,
10};
11use http::HeaderValue;
12use pin_project::pin_project;
13use tokio::io::AsyncWrite;
14use tracing::trace;
15
/// Fraction of the caller-supplied threshold that `CompressionEstimator::would_write_exceed_threshold`
/// actually compares against.
///
/// The compressed length is only an estimate, so writes are treated as exceeding the limit once the estimate
/// reaches 99% of it rather than 100%.
const THRESHOLD_RED_ZONE: f64 = 0.99;

/// `Content-Encoding` header value reported by the zlib compressor.
static CONTENT_ENCODING_DEFLATE: HeaderValue = HeaderValue::from_static("deflate");
/// `Content-Encoding` header value reported by the zstd compressor.
static CONTENT_ENCODING_ZSTD: HeaderValue = HeaderValue::from_static("zstd");
23
/// Selects the compression algorithm (and its level) used when building a `Compressor`.
#[derive(Copy, Clone, Debug)]
pub enum CompressionScheme {
    /// No compression: data is passed through to the underlying writer unchanged.
    Noop,
    /// zlib compression at the given level (advertised as `Content-Encoding: deflate`).
    Zlib(Level),
    /// zstd compression at the given level (advertised as `Content-Encoding: zstd`).
    Zstd(Level),
}
34
35impl CompressionScheme {
36 pub const fn noop() -> Self {
38 Self::Noop
39 }
40
41 pub const fn zlib_default() -> Self {
43 Self::Zlib(Level::Default)
44 }
45
46 pub const fn zstd_default() -> Self {
48 Self::Zstd(Level::Default)
49 }
50
51 pub fn new(scheme: &str, level: i32) -> Self {
57 match scheme {
58 "zlib" => CompressionScheme::zlib_default(),
59 "zstd" => Self::Zstd(Level::Precise(level)),
60 _ => Self::Zstd(Level::Default),
61 }
62 }
63}
64
/// An `AsyncWrite` adapter that forwards every operation to `inner` and keeps a running count of the bytes the
/// inner writer reports as written.
#[pin_project]
pub struct CountingWriter<W> {
    /// The wrapped writer; structurally pinned (`#[pin]`) so it can be polled through `Pin<&mut Self>`.
    #[pin]
    inner: W,
    /// Sum of the byte counts returned by successful `poll_write` calls on `inner`.
    total_written: u64,
}
71
72impl<W> CountingWriter<W> {
73 fn new(inner: W) -> Self {
74 Self {
75 inner,
76 total_written: 0,
77 }
78 }
79
80 fn total_written(&self) -> u64 {
81 self.total_written
82 }
83
84 fn into_inner(self) -> W {
85 self.inner
86 }
87}
88
/// Exposes how many output bytes a writer has produced so far.
pub trait WriteStatistics {
    /// Total number of output bytes produced so far.
    ///
    /// `CompressionEstimator::track_write` treats this as the compressed length and subtracts its previously seen
    /// value from it, so implementations are expected to report a value that never decreases.
    fn total_written(&self) -> u64;
}
94
95impl<W: AsyncWrite> AsyncWrite for CountingWriter<W> {
96 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
97 let mut this = self.project();
98 this.inner.as_mut().poll_write(cx, buf).map(|result| {
99 if let Ok(written) = &result {
100 *this.total_written += *written as u64;
101 }
102
103 result
104 })
105 }
106
107 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
108 self.project().inner.poll_flush(cx)
109 }
110
111 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
112 self.project().inner.poll_shutdown(cx)
113 }
114}
115
/// An `AsyncWrite` that compresses (or passes through) data on its way to `W`, chosen by a `CompressionScheme`.
///
/// Each variant can report the number of output bytes produced so far via `WriteStatistics`.
#[pin_project(project = CompressorProjected)]
pub enum Compressor<W: AsyncWrite> {
    /// Pass-through; output bytes are counted by the `CountingWriter`.
    Noop(#[pin] CountingWriter<W>),
    /// zlib compression; output size comes from the encoder's own `total_out` counter.
    Zlib(#[pin] ZlibEncoder<W>),
    /// zstd compression; output size is counted by the `CountingWriter` placed beneath the encoder.
    Zstd(#[pin] ZstdEncoder<CountingWriter<W>>),
}
129
130impl<W: AsyncWrite> Compressor<W> {
131 pub fn from_scheme(scheme: CompressionScheme, writer: W) -> Self {
133 match scheme {
134 CompressionScheme::Noop => Self::Noop(CountingWriter::new(writer)),
135 CompressionScheme::Zlib(level) => Self::Zlib(ZlibEncoder::with_quality(writer, level)),
136 CompressionScheme::Zstd(level) => Self::Zstd(ZstdEncoder::with_quality(CountingWriter::new(writer), level)),
137 }
138 }
139
140 pub fn into_inner(self) -> W {
142 match self {
143 Self::Noop(encoder) => encoder.into_inner(),
144 Self::Zlib(encoder) => encoder.into_inner(),
145 Self::Zstd(encoder) => encoder.into_inner().into_inner(),
146 }
147 }
148
149 pub fn content_encoding(&self) -> Option<HeaderValue> {
151 match self {
152 Self::Noop(_) => None,
153 Self::Zlib(_) => Some(CONTENT_ENCODING_DEFLATE.clone()),
154 Self::Zstd(_) => Some(CONTENT_ENCODING_ZSTD.clone()),
155 }
156 }
157}
158
159impl<W: AsyncWrite> AsyncWrite for Compressor<W> {
160 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
161 match self.project() {
162 CompressorProjected::Noop(encoder) => encoder.poll_write(cx, buf),
163 CompressorProjected::Zlib(encoder) => encoder.poll_write(cx, buf),
164 CompressorProjected::Zstd(encoder) => encoder.poll_write(cx, buf),
165 }
166 }
167
168 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
169 match self.project() {
170 CompressorProjected::Noop(encoder) => encoder.poll_flush(cx),
171 CompressorProjected::Zlib(encoder) => encoder.poll_flush(cx),
172 CompressorProjected::Zstd(encoder) => encoder.poll_flush(cx),
173 }
174 }
175
176 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
177 match self.project() {
178 CompressorProjected::Noop(encoder) => encoder.poll_shutdown(cx),
179 CompressorProjected::Zlib(encoder) => encoder.poll_shutdown(cx),
180 CompressorProjected::Zstd(encoder) => encoder.poll_shutdown(cx),
181 }
182 }
183}
184
185impl<W: AsyncWrite> WriteStatistics for Compressor<W> {
186 fn total_written(&self) -> u64 {
187 match self {
188 Compressor::Noop(encoder) => encoder.total_written(),
189 Compressor::Zlib(encoder) => encoder.total_out(),
190 Compressor::Zstd(encoder) => encoder.get_ref().total_written(),
191 }
192 }
193}
194
/// Estimates the compressed size of a payload while data is still being fed to a compressor.
///
/// A compressor may buffer input and only emit output in blocks, so its reported output length can lag behind
/// the data written to it. The estimator records uncompressed bytes as they are written. Whenever the compressor
/// reports new output, it recomputes the observed compression ratio (compressed / uncompressed). Uncompressed bytes
/// written since the last observed output ("in flight") are estimated by scaling them with that ratio.
#[derive(Debug, Default)]
pub struct CompressionEstimator {
    /// Uncompressed bytes tracked since the compressor last reported new output.
    in_flight_uncompressed_len: usize,
    /// Uncompressed bytes tracked since creation or the last `reset`.
    total_uncompressed_len: usize,
    /// The compressor's output length as of the last time it was observed to grow.
    total_compressed_len: u64,
    /// Compressed/uncompressed ratio as of the last observed output; `0.0` until then.
    current_compression_ratio: f64,
}
226
227impl CompressionEstimator {
228 pub fn track_write<W>(&mut self, compressor: &W, uncompressed_len: usize)
230 where
231 W: WriteStatistics,
232 {
233 self.in_flight_uncompressed_len += uncompressed_len;
234 self.total_uncompressed_len += uncompressed_len;
235
236 let compressed_len = compressor.total_written();
237 let compressed_len_delta = (compressed_len - self.total_compressed_len) as usize;
238 if compressed_len_delta > 0 {
239 self.current_compression_ratio = compressed_len as f64 / self.total_uncompressed_len as f64;
241 self.total_compressed_len = compressed_len;
242 self.in_flight_uncompressed_len = 0;
243
244 trace!(
245 block_size = compressed_len_delta,
246 uncompressed_len = self.total_uncompressed_len,
247 compressed_len = self.total_compressed_len,
248 compression_ratio = self.current_compression_ratio,
249 "Compressor wrote block to output stream."
250 );
251 }
252 }
253
254 pub fn reset(&mut self) {
256 self.in_flight_uncompressed_len = 0;
257 self.total_uncompressed_len = 0;
258 self.total_compressed_len = 0;
259 self.current_compression_ratio = 0.0;
260 }
261
262 pub fn estimated_len(&self) -> usize {
268 let estimated_in_flight_compressed_len =
269 (self.in_flight_uncompressed_len as f64 * self.current_compression_ratio) as usize;
270
271 self.total_compressed_len as usize + estimated_in_flight_compressed_len
272 }
273
274 pub fn would_write_exceed_threshold(&self, len: usize, threshold: usize) -> bool {
277 if len > threshold {
279 return true;
280 }
281
282 if self.total_compressed_len == 0 {
286 return false;
287 }
288
289 let adjusted_threshold = (threshold as f64 * THRESHOLD_RED_ZONE) as usize;
299 self.estimated_len() + len > adjusted_threshold
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double for a compressor: it records uncompressed writes and only "emits" output when `flush` is called,
    /// at a caller-chosen compression ratio.
    struct MockCompressor {
        // Uncompressed bytes written since the last `flush`.
        current_uncompressed_len: u64,
        // Uncompressed bytes written over the mock's lifetime.
        total_uncompressed_len: usize,
        // Simulated compressed output produced by all flushes so far.
        compressed_len: u64,
    }

    impl MockCompressor {
        fn new() -> Self {
            MockCompressor {
                current_uncompressed_len: 0,
                total_uncompressed_len: 0,
                compressed_len: 0,
            }
        }

        /// Simulates writing `n` uncompressed bytes without producing any output yet.
        fn write(&mut self, n: usize) {
            self.current_uncompressed_len += n as u64;
            self.total_uncompressed_len += n;
        }

        /// Simulates emitting a block: the pending bytes are "compressed" at `compression_ratio` (truncated toward
        /// zero) and added to the output length.
        fn flush(&mut self, compression_ratio: f64) {
            self.compressed_len += (self.current_uncompressed_len as f64 * compression_ratio) as u64;
            self.current_uncompressed_len = 0;
        }

        fn total_uncompressed_len(&self) -> usize {
            self.total_uncompressed_len
        }
    }

    impl WriteStatistics for MockCompressor {
        fn total_written(&self) -> u64 {
            self.compressed_len
        }
    }

    #[test]
    fn compression_estimator_no_output() {
        let estimator = CompressionEstimator::default();

        // Nothing compressed yet, so no estimate is possible: a write that fits on its own is allowed...
        assert!(!estimator.would_write_exceed_threshold(10, 100));
        // ...but a write larger than the threshold by itself is always rejected.
        assert!(estimator.would_write_exceed_threshold(100, 90));
    }

    #[test]
    fn compression_estimator_single_flush() {
        const MAX_COMPRESSED_LEN: usize = 100;
        const COMPRESSION_RATIO: f64 = 0.7;
        const WRITE_LEN: usize = 50;

        let mut estimator = CompressionEstimator::default();

        let mut compressor = MockCompressor::new();
        // No output observed yet, so the first write is allowed.
        assert!(!estimator.would_write_exceed_threshold(WRITE_LEN, MAX_COMPRESSED_LEN));

        compressor.write(WRITE_LEN);
        // 50 bytes at a 0.7 ratio => 35 compressed bytes.
        compressor.flush(COMPRESSION_RATIO);
        assert_eq!(compressor.total_written(), 35);

        estimator.track_write(&compressor, WRITE_LEN);

        // The red-zone threshold is 99 (100 * 0.99), and the pending write is counted uncompressed:
        // 35 + 100 = 135 > 99.
        assert!(estimator.would_write_exceed_threshold(100, MAX_COMPRESSED_LEN));

        // 35 + 50 = 85, which is within the red-zone threshold of 99.
        assert!(!estimator.would_write_exceed_threshold(WRITE_LEN, MAX_COMPRESSED_LEN));
    }

    #[test]
    fn compression_estimator_multiple_flush_partial() {
        const MAX_COMPRESSED_LEN: usize = 5000;
        const FIRST_COMPRESSION_RATIO: f64 = 0.7;
        const FIRST_WRITE_LEN: usize = 5000;
        const SECOND_COMPRESSION_RATIO: f64 = 2.1;
        const SECOND_WRITE_LEN: usize = 300;
        const THIRD_WRITE_LEN: usize = 820;

        let mut estimator = CompressionEstimator::default();

        let mut compressor = MockCompressor::new();
        assert!(!estimator.would_write_exceed_threshold(FIRST_WRITE_LEN, MAX_COMPRESSED_LEN));

        compressor.write(FIRST_WRITE_LEN);
        // 5000 bytes at 0.7 => 3500 compressed bytes.
        compressor.flush(FIRST_COMPRESSION_RATIO);
        assert_eq!(compressor.total_uncompressed_len(), FIRST_WRITE_LEN);
        assert_eq!(compressor.total_written(), 3500);

        estimator.track_write(&compressor, FIRST_WRITE_LEN);

        compressor.write(SECOND_WRITE_LEN);
        // A ratio above 1.0 makes this flush emit more bytes than were just written. This presumably simulates
        // draining data buffered from earlier writes. 300 bytes at 2.1 => 630 more bytes, for 4130 in total.
        compressor.flush(SECOND_COMPRESSION_RATIO);
        assert_eq!(compressor.total_uncompressed_len(), FIRST_WRITE_LEN + SECOND_WRITE_LEN);
        assert_eq!(compressor.total_written(), 4130);

        estimator.track_write(&compressor, SECOND_WRITE_LEN);

        // The red-zone threshold is 4950 (5000 * 0.99), and the estimate is exactly 4130 (nothing in flight).
        // A write of 820 lands exactly on the threshold: 4130 + 820 = 4950 is not strictly greater.
        assert!(!estimator.would_write_exceed_threshold(THIRD_WRITE_LEN, MAX_COMPRESSED_LEN));
    }
}