saluki_common/buf/
chunked.rs1use std::{
2 collections::VecDeque,
3 convert::Infallible,
4 io,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut};
10use http_body::{Body, Frame, SizeHint};
11use tokio::io::AsyncWrite;
12
13pub struct ChunkedBytesBuffer {
22 chunks: VecDeque<BytesMut>,
23 chunk_size: usize,
24 remaining_capacity: usize,
25}
26
27impl ChunkedBytesBuffer {
28 pub fn new(chunk_size: usize) -> Self {
30 Self {
31 chunks: VecDeque::new(),
32 chunk_size,
33 remaining_capacity: 0,
34 }
35 }
36
37 pub fn is_empty(&self) -> bool {
39 self.chunks.is_empty()
40 }
41
42 pub fn len(&self) -> usize {
44 self.chunks.iter().map(|chunk| chunk.remaining()).sum()
45 }
46
47 fn register_new_chunk(&mut self) {
48 self.remaining_capacity += self.chunk_size;
49 self.chunks.push_back(BytesMut::with_capacity(self.chunk_size));
50 }
51
52 fn ensure_capacity_for_write(&mut self) {
53 if self.remaining_capacity == 0 {
54 self.register_new_chunk();
55 }
56 }
57
58 pub fn freeze(self) -> FrozenChunkedBytesBuffer {
62 FrozenChunkedBytesBuffer {
63 chunks: self.chunks.into_iter().map(|chunk| chunk.freeze()).collect(),
64 }
65 }
66}
67
68impl AsyncWrite for ChunkedBytesBuffer {
69 fn poll_write(mut self: Pin<&mut Self>, _: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
70 self.put_slice(buf);
71 Poll::Ready(Ok(buf.len()))
72 }
73
74 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
75 Poll::Ready(Ok(()))
76 }
77
78 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
79 Poll::Ready(Ok(()))
80 }
81}
82
83unsafe impl BufMut for ChunkedBytesBuffer {
84 fn remaining_mut(&self) -> usize {
85 usize::MAX - self.len()
86 }
87
88 fn chunk_mut(&mut self) -> &mut UninitSlice {
89 self.ensure_capacity_for_write();
90 self.chunks.back_mut().unwrap().spare_capacity_mut().into()
91 }
92
93 unsafe fn advance_mut(&mut self, cnt: usize) {
94 self.chunks.back_mut().unwrap().advance_mut(cnt);
95 self.remaining_capacity -= cnt;
96 }
97}
98
99#[derive(Clone)]
103pub struct FrozenChunkedBytesBuffer {
104 chunks: VecDeque<Bytes>,
105}
106
107impl FrozenChunkedBytesBuffer {
108 pub fn is_empty(&self) -> bool {
110 self.chunks.is_empty()
111 }
112
113 pub fn len(&self) -> usize {
115 self.chunks.iter().map(|chunk| chunk.len()).sum()
116 }
117
118 pub fn into_bytes(mut self) -> Bytes {
122 if self.chunks.len() == 1 {
123 self.chunks.pop_front().unwrap()
124 } else {
125 let mut buf = BytesMut::new();
126 for chunk in self.chunks {
127 buf.extend_from_slice(&chunk);
128 }
129 buf.freeze()
130 }
131 }
132}
133
134impl From<Bytes> for FrozenChunkedBytesBuffer {
135 fn from(bytes: Bytes) -> Self {
136 let mut chunks = VecDeque::new();
137 if !bytes.is_empty() {
138 chunks.push_back(bytes);
139 }
140 Self { chunks }
141 }
142}
143
144impl Buf for FrozenChunkedBytesBuffer {
145 fn remaining(&self) -> usize {
146 self.chunks.iter().map(|chunk| chunk.remaining()).sum()
147 }
148
149 fn chunk(&self) -> &[u8] {
150 self.chunks.front().map_or(&[], |chunk| chunk.chunk())
151 }
152
153 fn advance(&mut self, mut cnt: usize) {
154 while cnt > 0 {
155 let chunk = self.chunks.front_mut().expect("no chunks left");
156 let chunk_remaining = chunk.remaining();
157 if cnt < chunk_remaining {
158 chunk.advance(cnt);
159 break;
160 }
161
162 chunk.advance(chunk_remaining);
163 cnt -= chunk_remaining;
164 self.chunks.pop_front();
165 }
166 }
167}
168
169impl Body for FrozenChunkedBytesBuffer {
170 type Data = Bytes;
171 type Error = Infallible;
172
173 fn poll_frame(
174 mut self: Pin<&mut Self>, _: &mut Context<'_>,
175 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
176 Poll::Ready(self.chunks.pop_front().map(|chunk| Ok(Frame::data(chunk))))
177 }
178
179 fn size_hint(&self) -> SizeHint {
180 SizeHint::with_exact(self.len() as u64)
181 }
182}
183
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use http_body_util::BodyExt as _;
    use tokio::io::AsyncWriteExt as _;
    use tokio_test::{assert_ready, task::spawn as test_spawn};

    use super::*;

    const TEST_CHUNK_SIZE: usize = 16;
    const TEST_BUF_CHUNK_SIZED: &[u8] = b"hello world!!!!!";
    const TEST_BUF_LESS_THAN_CHUNK_SIZED: &[u8] = b"hello world!";
    const TEST_BUF_GREATER_THAN_CHUNK_SIZED: &[u8] = b"hello world, here i come!";

    #[test]
    fn single_write_fits_within_single_chunk() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_LESS_THAN_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_LESS_THAN_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn single_write_fits_single_chunk_exactly() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn single_write_strides_two_chunks() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_GREATER_THAN_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_GREATER_THAN_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 2);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn two_writes_fit_two_chunks_exactly() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let first_n = result.unwrap();
        assert_eq!(first_n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);
        assert_eq!(chunked_buffer.remaining_capacity, 0);

        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let second_n = result.unwrap();
        assert_eq!(second_n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 2);
        assert_eq!(chunked_buffer.remaining_capacity, 0);
    }

    #[tokio::test]
    async fn all_chunks_returned_as_body() {
        let test_bufs = &[
            TEST_BUF_LESS_THAN_CHUNK_SIZED,
            TEST_BUF_CHUNK_SIZED,
            TEST_BUF_GREATER_THAN_CHUNK_SIZED,
        ];
        let test_bufs_total_len = test_bufs.iter().map(|buf| buf.len()).sum::<usize>();
        let required_chunks = test_bufs_total_len / TEST_CHUNK_SIZE
            + if test_bufs_total_len % TEST_CHUNK_SIZE > 0 {
                1
            } else {
                0
            };

        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);
        let total_capacity = required_chunks * TEST_CHUNK_SIZE;

        // Write every test buffer, tracking the exact bytes the body is expected to yield.
        let mut expected_aggregated_body = BytesMut::new();
        let mut total_written = 0;
        for test_buf in test_bufs {
            chunked_buffer.write_all(test_buf).await.unwrap();
            expected_aggregated_body.put(*test_buf);
            total_written += test_buf.len();
        }

        assert_eq!(test_bufs_total_len, total_written);
        assert_eq!(chunked_buffer.chunks.len(), required_chunks);
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - total_written);

        let read_chunked_buffer = chunked_buffer.freeze();

        let actual_aggregated_body = read_chunked_buffer.collect().await.expect("cannot fail").to_bytes();

        assert_eq!(expected_aggregated_body.freeze(), actual_aggregated_body);
    }
}