// saluki_common/buf/chunked.rs
use std::{
2 collections::VecDeque,
3 convert::Infallible,
4 io,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut};
10use http_body::{Body, Frame, SizeHint};
11use tokio::io::AsyncWrite;
12
/// A growable buffer that accumulates written bytes into fixed-size chunks.
///
/// Chunks of `chunk_size` bytes are allocated on demand as data is written
/// (via `BufMut` / `AsyncWrite`); call `freeze` to convert the accumulated
/// data into an immutable [`FrozenChunkedBytesBuffer`].
pub struct ChunkedBytesBuffer {
    // Chunks in write order; only the most recently allocated (back) chunk
    // receives new writes.
    chunks: VecDeque<BytesMut>,
    // Capacity, in bytes, given to each newly allocated chunk.
    chunk_size: usize,
    // Bytes of write capacity left before another chunk must be allocated.
    remaining_capacity: usize,
}
26
27impl ChunkedBytesBuffer {
28 pub fn new(chunk_size: usize) -> Self {
30 Self {
31 chunks: VecDeque::new(),
32 chunk_size,
33 remaining_capacity: 0,
34 }
35 }
36
37 pub fn is_empty(&self) -> bool {
39 self.chunks.is_empty()
40 }
41
42 pub fn len(&self) -> usize {
44 self.chunks.iter().map(|chunk| chunk.remaining()).sum()
45 }
46
47 fn register_new_chunk(&mut self) {
48 self.remaining_capacity += self.chunk_size;
49 self.chunks.push_back(BytesMut::with_capacity(self.chunk_size));
50 }
51
52 fn ensure_capacity_for_write(&mut self) {
53 if self.remaining_capacity == 0 {
54 self.register_new_chunk();
55 }
56 }
57
58 pub fn freeze(self) -> FrozenChunkedBytesBuffer {
62 FrozenChunkedBytesBuffer {
63 chunks: self.chunks.into_iter().map(|chunk| chunk.freeze()).collect(),
64 }
65 }
66}
67
68impl AsyncWrite for ChunkedBytesBuffer {
69 fn poll_write(mut self: Pin<&mut Self>, _: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
70 self.put_slice(buf);
71 Poll::Ready(Ok(buf.len()))
72 }
73
74 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
75 Poll::Ready(Ok(()))
76 }
77
78 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
79 Poll::Ready(Ok(()))
80 }
81}
82
83unsafe impl BufMut for ChunkedBytesBuffer {
84 fn remaining_mut(&self) -> usize {
85 usize::MAX - self.len()
86 }
87
88 fn chunk_mut(&mut self) -> &mut UninitSlice {
89 self.ensure_capacity_for_write();
90 self.chunks.back_mut().unwrap().spare_capacity_mut().into()
91 }
92
93 unsafe fn advance_mut(&mut self, cnt: usize) {
94 self.chunks.back_mut().unwrap().advance_mut(cnt);
95 self.remaining_capacity -= cnt;
96 }
97}
98
/// An immutable sequence of byte chunks, readable via [`Buf`] and usable as an
/// HTTP [`Body`].
///
/// Cloning is cheap: each chunk is a reference-counted [`Bytes`] handle, so no
/// byte data is copied.
#[derive(Clone)]
pub struct FrozenChunkedBytesBuffer {
    // Chunks in read order; consumed front-to-back by the `Buf`/`Body` impls.
    chunks: VecDeque<Bytes>,
}
106
107impl FrozenChunkedBytesBuffer {
108 pub fn is_empty(&self) -> bool {
110 self.chunks.is_empty()
111 }
112
113 pub fn len(&self) -> usize {
115 self.chunks.iter().map(|chunk| chunk.len()).sum()
116 }
117}
118
119impl From<Bytes> for FrozenChunkedBytesBuffer {
120 fn from(bytes: Bytes) -> Self {
121 let mut chunks = VecDeque::new();
122 if !bytes.is_empty() {
123 chunks.push_back(bytes);
124 }
125 Self { chunks }
126 }
127}
128
129impl Buf for FrozenChunkedBytesBuffer {
130 fn remaining(&self) -> usize {
131 self.chunks.iter().map(|chunk| chunk.remaining()).sum()
132 }
133
134 fn chunk(&self) -> &[u8] {
135 self.chunks.front().map_or(&[], |chunk| chunk.chunk())
136 }
137
138 fn advance(&mut self, mut cnt: usize) {
139 while cnt > 0 {
140 let chunk = self.chunks.front_mut().expect("no chunks left");
141 let chunk_remaining = chunk.remaining();
142 if cnt < chunk_remaining {
143 chunk.advance(cnt);
144 break;
145 }
146
147 chunk.advance(chunk_remaining);
148 cnt -= chunk_remaining;
149 self.chunks.pop_front();
150 }
151 }
152}
153
154impl Body for FrozenChunkedBytesBuffer {
155 type Data = Bytes;
156 type Error = Infallible;
157
158 fn poll_frame(
159 mut self: Pin<&mut Self>, _: &mut Context<'_>,
160 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
161 Poll::Ready(self.chunks.pop_front().map(|chunk| Ok(Frame::data(chunk))))
162 }
163
164 fn size_hint(&self) -> SizeHint {
165 SizeHint::with_exact(self.len() as u64)
166 }
167}
168
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use http_body_util::BodyExt as _;
    use tokio::io::AsyncWriteExt as _;
    use tokio_test::{assert_ready, task::spawn as test_spawn};

    use super::*;

    const TEST_CHUNK_SIZE: usize = 16;
    const TEST_BUF_CHUNK_SIZED: &[u8] = b"hello world!!!!!";
    const TEST_BUF_LESS_THAN_CHUNK_SIZED: &[u8] = b"hello world!";
    const TEST_BUF_GREATER_THAN_CHUNK_SIZED: &[u8] = b"hello world, here i come!";

    /// Drives a single `write` of `buf` to completion and asserts the whole
    /// buffer was accepted, returning the number of bytes written.
    fn write_fully(chunked_buffer: &mut ChunkedBytesBuffer, buf: &'static [u8]) -> usize {
        let mut fut = test_spawn(chunked_buffer.write(buf));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, buf.len());
        n
    }

    #[test]
    fn single_write_fits_within_single_chunk() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let n = write_fully(&mut chunked_buffer, TEST_BUF_LESS_THAN_CHUNK_SIZED);
        assert_eq!(chunked_buffer.chunks.len(), 1);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn single_write_fits_single_chunk_exactly() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let n = write_fully(&mut chunked_buffer, TEST_BUF_CHUNK_SIZED);
        assert_eq!(chunked_buffer.chunks.len(), 1);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn single_write_strides_two_chunks() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let n = write_fully(&mut chunked_buffer, TEST_BUF_GREATER_THAN_CHUNK_SIZED);
        assert_eq!(chunked_buffer.chunks.len(), 2);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn two_writes_fit_two_chunks_exactly() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        write_fully(&mut chunked_buffer, TEST_BUF_CHUNK_SIZED);
        assert_eq!(chunked_buffer.chunks.len(), 1);
        assert_eq!(chunked_buffer.remaining_capacity, 0);

        write_fully(&mut chunked_buffer, TEST_BUF_CHUNK_SIZED);
        assert_eq!(chunked_buffer.chunks.len(), 2);
        assert_eq!(chunked_buffer.remaining_capacity, 0);
    }

    #[tokio::test]
    async fn all_chunks_returned_as_body() {
        // Single source of truth for the test inputs; the original defined this
        // array twice verbatim.
        let test_bufs = &[
            TEST_BUF_LESS_THAN_CHUNK_SIZED,
            TEST_BUF_CHUNK_SIZED,
            TEST_BUF_GREATER_THAN_CHUNK_SIZED,
        ];
        let test_bufs_total_len = test_bufs.iter().map(|buf| buf.len()).sum::<usize>();
        // Number of chunks needed to hold all of the test data, rounded up.
        let required_chunks =
            test_bufs_total_len / TEST_CHUNK_SIZE + usize::from(test_bufs_total_len % TEST_CHUNK_SIZE > 0);

        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);
        let total_capacity = required_chunks * TEST_CHUNK_SIZE;

        let mut expected_aggregated_body = BytesMut::new();
        let mut total_written = 0;
        for test_buf in test_bufs {
            chunked_buffer.write_all(test_buf).await.unwrap();
            expected_aggregated_body.put(*test_buf);
            total_written += test_buf.len();
        }

        assert_eq!(test_bufs_total_len, total_written);
        assert_eq!(chunked_buffer.chunks.len(), required_chunks);
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - total_written);

        let read_chunked_buffer = chunked_buffer.freeze();

        // Collecting the body must reassemble exactly the bytes that were written.
        let actual_aggregated_body = read_chunked_buffer.collect().await.expect("cannot fail").to_bytes();

        assert_eq!(expected_aggregated_body.freeze(), actual_aggregated_body);
    }
}
300}