saluki_common/buf/chunked.rs

use std::{
    collections::VecDeque,
    convert::Infallible,
    io,
    pin::Pin,
    task::{Context, Poll},
};

use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut};
use http_body::{Body, Frame, SizeHint};
use tokio::io::AsyncWrite;

/// A bytes buffer that writes dynamically-sized payloads across multiple fixed-size chunks.
///
/// As callers write data to `ChunkedBytesBuffer`, it will allocate additional chunks as needed, and write the data
/// across these chunks. This allows for predictable memory usage by avoiding reallocations that overestimate the
/// necessary additional capacity, and provides mechanical sympathy to the allocator by using consistently-sized chunks.
///
/// `ChunkedBytesBuffer` implements [`AsyncWrite`], and its frozen form, [`FrozenChunkedBytesBuffer`], implements
/// [`Body`], allowing payloads to be written asynchronously and then used as the body of an HTTP request without any
/// additional allocations or copying/merging of data into a single buffer.
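///
/// A minimal usage sketch (the chunk size and payload are illustrative, the crate import is omitted, and a Tokio
/// async context is assumed):
///
/// ```ignore
/// use tokio::io::AsyncWriteExt as _;
///
/// // Data is buffered across 16 KiB chunks, allocated lazily as writes come in.
/// let mut buffer = ChunkedBytesBuffer::new(16 * 1024);
/// buffer.write_all(b"some payload bytes").await.unwrap();
/// assert_eq!(buffer.len(), 18);
///
/// // Freeze the buffer into a cheap-to-clone, read-only form that implements `Body`.
/// let frozen = buffer.freeze();
/// assert_eq!(frozen.len(), 18);
/// ```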
pub struct ChunkedBytesBuffer {
    chunks: VecDeque<BytesMut>,
    chunk_size: usize,
    remaining_capacity: usize,
}

impl ChunkedBytesBuffer {
    /// Creates a new `ChunkedBytesBuffer`, configured to use chunks of the given size.
    pub fn new(chunk_size: usize) -> Self {
        Self {
            chunks: VecDeque::new(),
            chunk_size,
            remaining_capacity: 0,
        }
    }

    /// Returns `true` if the buffer has no data.
    pub fn is_empty(&self) -> bool {
        self.chunks.is_empty()
    }

    /// Returns the number of bytes written to the buffer.
    pub fn len(&self) -> usize {
        self.chunks.iter().map(|chunk| chunk.remaining()).sum()
    }

    fn register_new_chunk(&mut self) {
        self.remaining_capacity += self.chunk_size;
        self.chunks.push_back(BytesMut::with_capacity(self.chunk_size));
    }

    fn ensure_capacity_for_write(&mut self) {
        if self.remaining_capacity == 0 {
            self.register_new_chunk();
        }
    }

    /// Consumes this buffer and returns a read-only version of it.
    ///
    /// All existing chunks at the time of calling this method will be present in the read-only buffer.
    pub fn freeze(self) -> FrozenChunkedBytesBuffer {
        FrozenChunkedBytesBuffer {
            chunks: self.chunks.into_iter().map(|chunk| chunk.freeze()).collect(),
        }
    }
}

impl AsyncWrite for ChunkedBytesBuffer {
    fn poll_write(mut self: Pin<&mut Self>, _: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
        // `put_slice` allocates additional chunks as needed, so the full buffer is always accepted.
        self.put_slice(buf);
        Poll::Ready(Ok(buf.len()))
    }

    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
}

unsafe impl BufMut for ChunkedBytesBuffer {
    fn remaining_mut(&self) -> usize {
        // The buffer can always allocate another chunk, so the only practical limit is `usize::MAX`.
        usize::MAX - self.len()
    }

    fn chunk_mut(&mut self) -> &mut UninitSlice {
        // Make sure at least one chunk with spare capacity exists before handing out a writable slice.
        self.ensure_capacity_for_write();
        self.chunks.back_mut().unwrap().spare_capacity_mut().into()
    }

    unsafe fn advance_mut(&mut self, cnt: usize) {
        // Writes always land in the most recently allocated chunk, so advance it and account for the
        // consumed capacity.
        self.chunks.back_mut().unwrap().advance_mut(cnt);
        self.remaining_capacity -= cnt;
    }
}

/// A frozen, read-only version of [`ChunkedBytesBuffer`].
///
/// `FrozenChunkedBytesBuffer` can be cheaply cloned, and allows for sharing the underlying chunks among multiple tasks.
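///
/// A minimal sketch of sharing a frozen buffer between tasks (illustrative only; assumes `buffer` is a
/// `ChunkedBytesBuffer` that has already been written to, and that a Tokio runtime is available):
///
/// ```ignore
/// use bytes::Buf as _;
///
/// let frozen = buffer.freeze();
///
/// // Clones are cheap: the underlying chunks are shared, not copied.
/// let frozen_clone = frozen.clone();
/// tokio::spawn(async move {
///     // Consume the shared data through the `Buf` interface in another task.
///     let mut body = frozen_clone;
///     let payload = body.copy_to_bytes(body.remaining());
///     // ... do something with `payload` ...
/// });
/// ```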
#[derive(Clone)]
pub struct FrozenChunkedBytesBuffer {
    chunks: VecDeque<Bytes>,
}

impl FrozenChunkedBytesBuffer {
    /// Returns `true` if the buffer has no data.
    pub fn is_empty(&self) -> bool {
        self.chunks.is_empty()
    }

    /// Returns the number of bytes written to the buffer.
    pub fn len(&self) -> usize {
        self.chunks.iter().map(|chunk| chunk.len()).sum()
    }
}

impl Buf for FrozenChunkedBytesBuffer {
    fn remaining(&self) -> usize {
        self.chunks.iter().map(|chunk| chunk.remaining()).sum()
    }

    fn chunk(&self) -> &[u8] {
        self.chunks.front().map_or(&[], |chunk| chunk.chunk())
    }

    fn advance(&mut self, mut cnt: usize) {
        // Walk the chunks from the front, popping any chunk that gets fully consumed so that `chunk`
        // always points at the first chunk with remaining data.
        while cnt > 0 {
            let chunk = self.chunks.front_mut().expect("no chunks left");
            let chunk_remaining = chunk.remaining();
            if cnt < chunk_remaining {
                chunk.advance(cnt);
                break;
            }

            chunk.advance(chunk_remaining);
            cnt -= chunk_remaining;
            self.chunks.pop_front();
        }
    }
}

impl Body for FrozenChunkedBytesBuffer {
    type Data = Bytes;
    type Error = Infallible;

    fn poll_frame(
        mut self: Pin<&mut Self>, _: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        // Each chunk is emitted as its own data frame, in write order, until no chunks remain.
        Poll::Ready(self.chunks.pop_front().map(|chunk| Ok(Frame::data(chunk))))
    }

    fn size_hint(&self) -> SizeHint {
        SizeHint::with_exact(self.len() as u64)
    }
}

#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use http_body_util::BodyExt as _;
    use tokio::io::AsyncWriteExt as _;
    use tokio_test::{assert_ready, task::spawn as test_spawn};

    use super::*;

    const TEST_CHUNK_SIZE: usize = 16;
    const TEST_BUF_CHUNK_SIZED: &[u8] = b"hello world!!!!!";
    const TEST_BUF_LESS_THAN_CHUNK_SIZED: &[u8] = b"hello world!";
    const TEST_BUF_GREATER_THAN_CHUNK_SIZED: &[u8] = b"hello world, here i come!";

    #[test]
    fn single_write_fits_within_single_chunk() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        // Fits within a single chunk, so it should complete without blocking.
        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_LESS_THAN_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_LESS_THAN_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn single_write_fits_single_chunk_exactly() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        // Fills a single chunk exactly, so it should complete without blocking.
        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn single_write_strides_two_chunks() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        // This won't fit in a single chunk, but should fit within two.
        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_GREATER_THAN_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_GREATER_THAN_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 2);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn two_writes_fit_two_chunks_exactly() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        // First write acquires one chunk, and fills it up entirely.
        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let first_n = result.unwrap();
        assert_eq!(first_n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);
        assert_eq!(chunked_buffer.remaining_capacity, 0);

        // Second write acquires an additional chunk, and also fills it up entirely.
        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let second_n = result.unwrap();
        assert_eq!(second_n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 2);
        assert_eq!(chunked_buffer.remaining_capacity, 0);
    }

    #[tokio::test]
    async fn all_chunks_returned_as_body() {
        let test_bufs = &[
            TEST_BUF_LESS_THAN_CHUNK_SIZED,
            TEST_BUF_CHUNK_SIZED,
            TEST_BUF_GREATER_THAN_CHUNK_SIZED,
        ];
        let test_bufs_total_len = test_bufs.iter().map(|buf| buf.len()).sum::<usize>();
        let required_chunks = test_bufs_total_len / TEST_CHUNK_SIZE
            + if test_bufs_total_len % TEST_CHUNK_SIZE > 0 {
                1
            } else {
                0
            };

        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);
        let total_capacity = required_chunks * TEST_CHUNK_SIZE;

        // Do three writes, using the less than/exactly/greater than-sized test buffers.
        //
        // We'll write these buffers, concatenated, to a single buffer that we'll use at the end to
        // compare the collected `Body`-based output.
        let mut expected_aggregated_body = BytesMut::new();
        let mut total_written = 0;
        for test_buf in test_bufs {
            chunked_buffer.write_all(test_buf).await.unwrap();
            expected_aggregated_body.put(*test_buf);
            total_written += test_buf.len();
        }

        assert_eq!(test_bufs_total_len, total_written);
        assert_eq!(chunked_buffer.chunks.len(), required_chunks);
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - total_written);

        let read_chunked_buffer = chunked_buffer.freeze();

        // We should now be able to collect the chunked buffer as a `Body`, into a single output buffer.
        let actual_aggregated_body = read_chunked_buffer.collect().await.expect("cannot fail").to_bytes();

        assert_eq!(expected_aggregated_body.freeze(), actual_aggregated_body);
    }
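
    // A small supplementary check: a sketch verifying that the frozen buffer can also be consumed
    // through its `Buf` implementation, advancing across a chunk boundary.
    #[tokio::test]
    async fn frozen_buffer_readable_via_buf() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);
        chunked_buffer.write_all(TEST_BUF_GREATER_THAN_CHUNK_SIZED).await.unwrap();

        let mut frozen = chunked_buffer.freeze();
        let total_len = frozen.remaining();
        assert_eq!(total_len, TEST_BUF_GREATER_THAN_CHUNK_SIZED.len());

        // `copy_to_bytes` drains the buffer via `chunk`/`advance`, crossing the chunk boundary.
        let collected = frozen.copy_to_bytes(total_len);
        assert_eq!(&collected[..], TEST_BUF_GREATER_THAN_CHUNK_SIZED);
        assert!(frozen.is_empty());
    }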
}