saluki_common/buf/
chunked.rs

1use std::{
2    collections::VecDeque,
3    convert::Infallible,
4    io,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut};
10use http_body::{Body, Frame, SizeHint};
11use tokio::io::AsyncWrite;
12
/// A bytes buffer that writes dynamically-sized payloads across multiple fixed-size chunks.
///
/// As callers write data to `ChunkedBytesBuffer`, it will allocate additional chunks as needed, and write the data
/// across these chunks. This allows for predictable memory usage by avoiding reallocations that overestimate the
/// necessary additional capacity, and provides mechanical sympathy to the allocator by using consistently-sized chunks.
///
/// `ChunkedBytesBuffer` implements [`AsyncWrite`] and [`BufMut`], allowing it to be written to both asynchronously and
/// synchronously. Once writing is complete, it can be converted into a [`FrozenChunkedBytesBuffer`], which implements
/// [`Body`] and can be used as the body of an HTTP request without any additional allocations and copying/merging of
/// data into a single buffer.
pub struct ChunkedBytesBuffer {
    /// Allocated chunks, oldest first. New chunks are pushed to the back, and writes always target the back chunk.
    chunks: VecDeque<BytesMut>,
    /// Capacity, in bytes, requested for each newly allocated chunk.
    chunk_size: usize,
    /// Number of bytes of chunk capacity that have not been written to yet. Grows by `chunk_size` whenever a chunk is
    /// registered, and shrinks as writes advance.
    remaining_capacity: usize,
}
26
27impl ChunkedBytesBuffer {
28    /// Creates a new `ChunkedBytesBuffer`, configured to use chunks of the given size.
29    pub fn new(chunk_size: usize) -> Self {
30        Self {
31            chunks: VecDeque::new(),
32            chunk_size,
33            remaining_capacity: 0,
34        }
35    }
36
37    /// Returns `true` if the buffer has no data.
38    pub fn is_empty(&self) -> bool {
39        self.chunks.is_empty()
40    }
41
42    /// Returns the number of bytes written to the buffer.
43    pub fn len(&self) -> usize {
44        self.chunks.iter().map(|chunk| chunk.remaining()).sum()
45    }
46
47    fn register_new_chunk(&mut self) {
48        self.remaining_capacity += self.chunk_size;
49        self.chunks.push_back(BytesMut::with_capacity(self.chunk_size));
50    }
51
52    fn ensure_capacity_for_write(&mut self) {
53        if self.remaining_capacity == 0 {
54            self.register_new_chunk();
55        }
56    }
57
58    /// Consumes this buffer and returns a read-only version of it.
59    ///
60    /// All existing chunks at the time of calling this method will be present in the read-only buffer.
61    pub fn freeze(self) -> FrozenChunkedBytesBuffer {
62        FrozenChunkedBytesBuffer {
63            chunks: self.chunks.into_iter().map(|chunk| chunk.freeze()).collect(),
64        }
65    }
66}
67
68impl AsyncWrite for ChunkedBytesBuffer {
69    fn poll_write(mut self: Pin<&mut Self>, _: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
70        self.put_slice(buf);
71        Poll::Ready(Ok(buf.len()))
72    }
73
74    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
75        Poll::Ready(Ok(()))
76    }
77
78    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
79        Poll::Ready(Ok(()))
80    }
81}
82
83unsafe impl BufMut for ChunkedBytesBuffer {
84    fn remaining_mut(&self) -> usize {
85        usize::MAX - self.len()
86    }
87
88    fn chunk_mut(&mut self) -> &mut UninitSlice {
89        self.ensure_capacity_for_write();
90        self.chunks.back_mut().unwrap().spare_capacity_mut().into()
91    }
92
93    unsafe fn advance_mut(&mut self, cnt: usize) {
94        self.chunks.back_mut().unwrap().advance_mut(cnt);
95        self.remaining_capacity -= cnt;
96    }
97}
98
/// A frozen, read-only version of [`ChunkedBytesBuffer`].
///
/// `FrozenChunkedBytesBuffer` can be cheaply cloned, and allows for sharing the underlying chunks among multiple tasks.
///
/// It implements both [`Buf`] and [`Body`], and reading through either interface consumes chunks from the front.
#[derive(Clone)]
pub struct FrozenChunkedBytesBuffer {
    /// Frozen chunks in read order: the chunk at the front is the next one to be read.
    chunks: VecDeque<Bytes>,
}
106
107impl FrozenChunkedBytesBuffer {
108    /// Returns `true` if the buffer has no data.
109    pub fn is_empty(&self) -> bool {
110        self.chunks.is_empty()
111    }
112
113    /// Returns the number of bytes written to the buffer.
114    pub fn len(&self) -> usize {
115        self.chunks.iter().map(|chunk| chunk.len()).sum()
116    }
117}
118
119impl From<Bytes> for FrozenChunkedBytesBuffer {
120    fn from(bytes: Bytes) -> Self {
121        let mut chunks = VecDeque::new();
122        if !bytes.is_empty() {
123            chunks.push_back(bytes);
124        }
125        Self { chunks }
126    }
127}
128
129impl Buf for FrozenChunkedBytesBuffer {
130    fn remaining(&self) -> usize {
131        self.chunks.iter().map(|chunk| chunk.remaining()).sum()
132    }
133
134    fn chunk(&self) -> &[u8] {
135        self.chunks.front().map_or(&[], |chunk| chunk.chunk())
136    }
137
138    fn advance(&mut self, mut cnt: usize) {
139        while cnt > 0 {
140            let chunk = self.chunks.front_mut().expect("no chunks left");
141            let chunk_remaining = chunk.remaining();
142            if cnt < chunk_remaining {
143                chunk.advance(cnt);
144                break;
145            }
146
147            chunk.advance(chunk_remaining);
148            cnt -= chunk_remaining;
149            self.chunks.pop_front();
150        }
151    }
152}
153
154impl Body for FrozenChunkedBytesBuffer {
155    type Data = Bytes;
156    type Error = Infallible;
157
158    fn poll_frame(
159        mut self: Pin<&mut Self>, _: &mut Context<'_>,
160    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
161        Poll::Ready(self.chunks.pop_front().map(|chunk| Ok(Frame::data(chunk))))
162    }
163
164    fn size_hint(&self) -> SizeHint {
165        SizeHint::with_exact(self.len() as u64)
166    }
167}
168
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use http_body_util::BodyExt as _;
    use tokio::io::AsyncWriteExt as _;
    use tokio_test::{assert_ready, task::spawn as test_spawn};

    use super::*;

    const TEST_CHUNK_SIZE: usize = 16;
    const TEST_BUF_CHUNK_SIZED: &[u8] = b"hello world!!!!!";
    const TEST_BUF_LESS_THAN_CHUNK_SIZED: &[u8] = b"hello world!";
    const TEST_BUF_GREATER_THAN_CHUNK_SIZED: &[u8] = b"hello world, here i come!";

    /// Issues a single `write` call against `buffer`, asserting that it completes on the first poll, and returns the
    /// number of bytes that were written.
    fn write_without_blocking(buffer: &mut ChunkedBytesBuffer, payload: &[u8]) -> usize {
        let mut write_fut = test_spawn(buffer.write(payload));
        let result = assert_ready!(write_fut.poll());
        result.unwrap()
    }

    /// Writes `payload` to a fresh buffer in a single call, then checks the bytes written, the number of chunks
    /// allocated, and the remaining capacity accounting.
    fn check_single_write(payload: &[u8], expected_chunks: usize) {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let n = write_without_blocking(&mut chunked_buffer, payload);
        assert_eq!(n, payload.len());
        assert_eq!(chunked_buffer.chunks.len(), expected_chunks);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn single_write_fits_within_single_chunk() {
        // Fits within a single chunk, so it should complete without blocking.
        check_single_write(TEST_BUF_LESS_THAN_CHUNK_SIZED, 1);
    }

    #[test]
    fn single_write_fits_single_chunk_exactly() {
        // Fills a single chunk exactly, so it should complete without blocking.
        check_single_write(TEST_BUF_CHUNK_SIZED, 1);
    }

    #[test]
    fn single_write_strides_two_chunks() {
        // This won't fit in a single chunk, but should fit within two.
        check_single_write(TEST_BUF_GREATER_THAN_CHUNK_SIZED, 2);
    }

    #[test]
    fn two_writes_fit_two_chunks_exactly() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        // Each write acquires one additional chunk and fills it up entirely, so after the Nth write there should be
        // exactly N chunks with no capacity left over.
        for expected_chunks in [1, 2] {
            let n = write_without_blocking(&mut chunked_buffer, TEST_BUF_CHUNK_SIZED);

            assert_eq!(n, TEST_BUF_CHUNK_SIZED.len());
            assert_eq!(chunked_buffer.chunks.len(), expected_chunks);
            assert_eq!(chunked_buffer.remaining_capacity, 0);
        }
    }

    #[tokio::test]
    async fn all_chunks_returned_as_body() {
        // Three payloads: smaller than, exactly equal to, and larger than the chunk size.
        let test_bufs = &[
            TEST_BUF_LESS_THAN_CHUNK_SIZED,
            TEST_BUF_CHUNK_SIZED,
            TEST_BUF_GREATER_THAN_CHUNK_SIZED,
        ];
        let test_bufs_total_len = test_bufs.iter().map(|buf| buf.len()).sum::<usize>();

        // Number of chunks needed to hold everything, rounding up for a partially-filled final chunk.
        let has_partial_chunk = test_bufs_total_len % TEST_CHUNK_SIZE > 0;
        let required_chunks = test_bufs_total_len / TEST_CHUNK_SIZE + usize::from(has_partial_chunk);
        let total_capacity = required_chunks * TEST_CHUNK_SIZE;

        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        // Write each payload to the chunked buffer while also concatenating them into a single contiguous buffer,
        // which is what the collected `Body`-based output is compared against at the end.
        let mut expected_aggregated_body = BytesMut::new();
        let mut total_written = 0;
        for &test_buf in test_bufs {
            chunked_buffer.write_all(test_buf).await.unwrap();
            expected_aggregated_body.extend_from_slice(test_buf);
            total_written += test_buf.len();
        }

        assert_eq!(test_bufs_total_len, total_written);
        assert_eq!(chunked_buffer.chunks.len(), required_chunks);
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - total_written);

        let read_chunked_buffer = chunked_buffer.freeze();

        // We should now be able to collect the chunked buffer as a `Body`, into a single output buffer.
        let collected = read_chunked_buffer.collect().await.expect("cannot fail");
        let actual_aggregated_body = collected.to_bytes();

        assert_eq!(expected_aggregated_body.freeze(), actual_aggregated_body);
    }
}