saluki_common/buf/
chunked.rs

1use std::{
2    collections::VecDeque,
3    convert::Infallible,
4    io,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut};
10use http_body::{Body, Frame, SizeHint};
11use tokio::io::AsyncWrite;
12
/// A bytes buffer that writes dynamically-sized payloads across multiple fixed-size chunks.
///
/// As callers write data to `ChunkedBytesBuffer`, it will allocate additional chunks as needed, and write the data
/// across these chunks. This allows for predictable memory usage by avoiding reallocations that overestimate the
/// necessary additional capacity, and provides mechanical sympathy to the allocator by using consistently-sized chunks.
///
/// `ChunkedBytesBuffer` implements [`AsyncWrite`], allowing it to be asynchronously written to, and can be converted
/// (via [`freeze`][Self::freeze]) into a [`FrozenChunkedBytesBuffer`] that implements [`Body`], for use as the body of
/// an HTTP request without any additional allocations and copying/merging of data into a single buffer.
pub struct ChunkedBytesBuffer {
    // Chunks in write order. Only the back chunk ever has spare capacity, since new
    // chunks are only allocated once the current back chunk is full.
    chunks: VecDeque<BytesMut>,
    // Fixed capacity, in bytes, of every newly-allocated chunk.
    chunk_size: usize,
    // Unused capacity, in bytes, of the back chunk (zero when there are no chunks).
    remaining_capacity: usize,
}
26
27impl ChunkedBytesBuffer {
28    /// Creates a new `ChunkedBytesBuffer`, configured to use chunks of the given size.
29    pub fn new(chunk_size: usize) -> Self {
30        Self {
31            chunks: VecDeque::new(),
32            chunk_size,
33            remaining_capacity: 0,
34        }
35    }
36
37    /// Returns `true` if the buffer has no data.
38    pub fn is_empty(&self) -> bool {
39        self.chunks.is_empty()
40    }
41
42    /// Returns the number of bytes written to the buffer.
43    pub fn len(&self) -> usize {
44        self.chunks.iter().map(|chunk| chunk.remaining()).sum()
45    }
46
47    fn register_new_chunk(&mut self) {
48        self.remaining_capacity += self.chunk_size;
49        self.chunks.push_back(BytesMut::with_capacity(self.chunk_size));
50    }
51
52    fn ensure_capacity_for_write(&mut self) {
53        if self.remaining_capacity == 0 {
54            self.register_new_chunk();
55        }
56    }
57
58    /// Consumes this buffer and returns a read-only version of it.
59    ///
60    /// All existing chunks at the time of calling this method will be present in the read-only buffer.
61    pub fn freeze(self) -> FrozenChunkedBytesBuffer {
62        FrozenChunkedBytesBuffer {
63            chunks: self.chunks.into_iter().map(|chunk| chunk.freeze()).collect(),
64        }
65    }
66}
67
68impl AsyncWrite for ChunkedBytesBuffer {
69    fn poll_write(mut self: Pin<&mut Self>, _: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
70        self.put_slice(buf);
71        Poll::Ready(Ok(buf.len()))
72    }
73
74    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
75        Poll::Ready(Ok(()))
76    }
77
78    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
79        Poll::Ready(Ok(()))
80    }
81}
82
unsafe impl BufMut for ChunkedBytesBuffer {
    fn remaining_mut(&self) -> usize {
        // The buffer grows by allocating new chunks on demand, so writable capacity is
        // effectively unbounded; report the maximum that the `BufMut` contract allows.
        usize::MAX - self.len()
    }

    fn chunk_mut(&mut self) -> &mut UninitSlice {
        // Only the back chunk ever has spare capacity, so ensure one exists (allocating a
        // fresh chunk if the buffer is full or empty) before exposing its unwritten tail.
        //
        // The `unwrap` cannot fail: `ensure_capacity_for_write` guarantees at least one chunk.
        self.ensure_capacity_for_write();
        self.chunks.back_mut().unwrap().spare_capacity_mut().into()
    }

    unsafe fn advance_mut(&mut self, cnt: usize) {
        // SAFETY: per the `BufMut` contract, the caller has initialized `cnt` bytes of the
        // slice returned by `chunk_mut`, which always comes from the back chunk, so it's
        // sound to advance that chunk's length by `cnt` here.
        //
        // NOTE(review): this panics (unwrap / underflow) if called without a preceding
        // `chunk_mut` or with `cnt` larger than the back chunk's spare capacity — both are
        // contract violations by the caller.
        self.chunks.back_mut().unwrap().advance_mut(cnt);
        self.remaining_capacity -= cnt;
    }
}
98
/// A frozen, read-only version of [`ChunkedBytesBuffer`].
///
/// `FrozenChunkedBytesBuffer` can be cheaply cloned, and allows for sharing the underlying chunks among multiple tasks.
#[derive(Clone)]
pub struct FrozenChunkedBytesBuffer {
    // Read-only chunks in write order; consumed from the front as the buffer is read.
    chunks: VecDeque<Bytes>,
}
106
107impl FrozenChunkedBytesBuffer {
108    /// Returns `true` if the buffer has no data.
109    pub fn is_empty(&self) -> bool {
110        self.chunks.is_empty()
111    }
112
113    /// Returns the number of bytes written to the buffer.
114    pub fn len(&self) -> usize {
115        self.chunks.iter().map(|chunk| chunk.len()).sum()
116    }
117
118    /// Consumes the frozen buffer and returns a single `Bytes` value representing all chunks.
119    ///
120    /// This method provides an optimized implementation for single chunk buffers to avoid allocations.
121    pub fn into_bytes(mut self) -> Bytes {
122        if self.chunks.len() == 1 {
123            self.chunks.pop_front().unwrap()
124        } else {
125            let mut buf = BytesMut::new();
126            for chunk in self.chunks {
127                buf.extend_from_slice(&chunk);
128            }
129            buf.freeze()
130        }
131    }
132}
133
134impl From<Bytes> for FrozenChunkedBytesBuffer {
135    fn from(bytes: Bytes) -> Self {
136        let mut chunks = VecDeque::new();
137        if !bytes.is_empty() {
138            chunks.push_back(bytes);
139        }
140        Self { chunks }
141    }
142}
143
144impl Buf for FrozenChunkedBytesBuffer {
145    fn remaining(&self) -> usize {
146        self.chunks.iter().map(|chunk| chunk.remaining()).sum()
147    }
148
149    fn chunk(&self) -> &[u8] {
150        self.chunks.front().map_or(&[], |chunk| chunk.chunk())
151    }
152
153    fn advance(&mut self, mut cnt: usize) {
154        while cnt > 0 {
155            let chunk = self.chunks.front_mut().expect("no chunks left");
156            let chunk_remaining = chunk.remaining();
157            if cnt < chunk_remaining {
158                chunk.advance(cnt);
159                break;
160            }
161
162            chunk.advance(chunk_remaining);
163            cnt -= chunk_remaining;
164            self.chunks.pop_front();
165        }
166    }
167}
168
169impl Body for FrozenChunkedBytesBuffer {
170    type Data = Bytes;
171    type Error = Infallible;
172
173    fn poll_frame(
174        mut self: Pin<&mut Self>, _: &mut Context<'_>,
175    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
176        Poll::Ready(self.chunks.pop_front().map(|chunk| Ok(Frame::data(chunk))))
177    }
178
179    fn size_hint(&self) -> SizeHint {
180        SizeHint::with_exact(self.len() as u64)
181    }
182}
183
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use http_body_util::BodyExt as _;
    use tokio::io::AsyncWriteExt as _;
    use tokio_test::{assert_ready, task::spawn as test_spawn};

    use super::*;

    const TEST_CHUNK_SIZE: usize = 16;
    const TEST_BUF_CHUNK_SIZED: &[u8] = b"hello world!!!!!";
    const TEST_BUF_LESS_THAN_CHUNK_SIZED: &[u8] = b"hello world!";
    const TEST_BUF_GREATER_THAN_CHUNK_SIZED: &[u8] = b"hello world, here i come!";

    #[test]
    fn single_write_fits_within_single_chunk() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        // Fits within a single buffer, so it should complete without blocking.
        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_LESS_THAN_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_LESS_THAN_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn single_write_fits_single_chunk_exactly() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        // Fits within a single buffer, so it should complete without blocking.
        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn single_write_strides_two_chunks() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        // This won't fit in a single chunk, but should fit within two.
        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_GREATER_THAN_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_GREATER_THAN_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 2);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn two_writes_fit_two_chunks_exactly() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        // First write acquires one chunk, and fills it up entirely.
        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let first_n = result.unwrap();
        assert_eq!(first_n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);
        assert_eq!(chunked_buffer.remaining_capacity, 0);

        // Second write acquires an additional chunk, and also fills it up entirely.
        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let second_n = result.unwrap();
        assert_eq!(second_n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 2);
        assert_eq!(chunked_buffer.remaining_capacity, 0);
    }

    #[tokio::test]
    async fn all_chunks_returned_as_body() {
        let test_bufs = &[
            TEST_BUF_LESS_THAN_CHUNK_SIZED,
            TEST_BUF_CHUNK_SIZED,
            TEST_BUF_GREATER_THAN_CHUNK_SIZED,
        ];
        let test_bufs_total_len = test_bufs.iter().map(|buf| buf.len()).sum::<usize>();
        // Round up: a partially-filled trailing chunk still requires a whole chunk.
        let required_chunks = test_bufs_total_len / TEST_CHUNK_SIZE
            + if test_bufs_total_len % TEST_CHUNK_SIZE > 0 {
                1
            } else {
                0
            };

        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);
        let total_capacity = required_chunks * TEST_CHUNK_SIZE;

        // Do three writes, using the less than/exactly/greater than-sized test buffers.
        //
        // We'll write these buffers, concatenated, to a single buffer that we'll use at the end to
        // compare the collected `Body`-based output.
        let mut expected_aggregated_body = BytesMut::new();
        let mut total_written = 0;
        for test_buf in test_bufs {
            chunked_buffer.write_all(test_buf).await.unwrap();
            expected_aggregated_body.put(*test_buf);
            total_written += test_buf.len();
        }

        assert_eq!(test_bufs_total_len, total_written);
        assert_eq!(chunked_buffer.chunks.len(), required_chunks);
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - total_written);

        let read_chunked_buffer = chunked_buffer.freeze();

        // We should now be able to collect the chunked buffer as a `Body`, into a single output buffer.
        let actual_aggregated_body = read_chunked_buffer.collect().await.expect("cannot fail").to_bytes();

        assert_eq!(expected_aggregated_body.freeze(), actual_aggregated_body);
    }
}
315}