saluki_common/buf/
chunked.rs1use std::{
2 collections::VecDeque,
3 convert::Infallible,
4 io,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut};
10use http_body::{Body, Frame, SizeHint};
11use tokio::io::AsyncWrite;
12
13pub struct ChunkedBytesBuffer {
22 chunks: VecDeque<BytesMut>,
23 chunk_size: usize,
24 remaining_capacity: usize,
25}
26
27impl ChunkedBytesBuffer {
28 pub fn new(chunk_size: usize) -> Self {
30 Self {
31 chunks: VecDeque::new(),
32 chunk_size,
33 remaining_capacity: 0,
34 }
35 }
36
37 pub fn is_empty(&self) -> bool {
39 self.chunks.is_empty()
40 }
41
42 pub fn len(&self) -> usize {
44 self.chunks.iter().map(|chunk| chunk.remaining()).sum()
45 }
46
47 fn register_new_chunk(&mut self) {
48 self.remaining_capacity += self.chunk_size;
49 self.chunks.push_back(BytesMut::with_capacity(self.chunk_size));
50 }
51
52 fn ensure_capacity_for_write(&mut self) {
53 if self.remaining_capacity == 0 {
54 self.register_new_chunk();
55 }
56 }
57
58 pub fn freeze(self) -> FrozenChunkedBytesBuffer {
62 FrozenChunkedBytesBuffer {
63 chunks: self.chunks.into_iter().map(|chunk| chunk.freeze()).collect(),
64 }
65 }
66}
67
68impl AsyncWrite for ChunkedBytesBuffer {
69 fn poll_write(mut self: Pin<&mut Self>, _: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
70 self.put_slice(buf);
71 Poll::Ready(Ok(buf.len()))
72 }
73
74 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
75 Poll::Ready(Ok(()))
76 }
77
78 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
79 Poll::Ready(Ok(()))
80 }
81}
82
83unsafe impl BufMut for ChunkedBytesBuffer {
84 fn remaining_mut(&self) -> usize {
85 usize::MAX - self.len()
86 }
87
88 fn chunk_mut(&mut self) -> &mut UninitSlice {
89 self.ensure_capacity_for_write();
90 self.chunks.back_mut().unwrap().spare_capacity_mut().into()
91 }
92
93 unsafe fn advance_mut(&mut self, cnt: usize) {
94 self.chunks.back_mut().unwrap().advance_mut(cnt);
95 self.remaining_capacity -= cnt;
96 }
97}
98
99#[derive(Clone)]
103pub struct FrozenChunkedBytesBuffer {
104 chunks: VecDeque<Bytes>,
105}
106
107impl FrozenChunkedBytesBuffer {
108 pub fn is_empty(&self) -> bool {
110 self.chunks.is_empty()
111 }
112
113 pub fn len(&self) -> usize {
115 self.chunks.iter().map(|chunk| chunk.len()).sum()
116 }
117}
118
119impl Buf for FrozenChunkedBytesBuffer {
120 fn remaining(&self) -> usize {
121 self.chunks.iter().map(|chunk| chunk.remaining()).sum()
122 }
123
124 fn chunk(&self) -> &[u8] {
125 self.chunks.front().map_or(&[], |chunk| chunk.chunk())
126 }
127
128 fn advance(&mut self, mut cnt: usize) {
129 while cnt > 0 {
130 let chunk = self.chunks.front_mut().expect("no chunks left");
131 let chunk_remaining = chunk.remaining();
132 if cnt < chunk_remaining {
133 chunk.advance(cnt);
134 break;
135 }
136
137 chunk.advance(chunk_remaining);
138 cnt -= chunk_remaining;
139 self.chunks.pop_front();
140 }
141 }
142}
143
144impl Body for FrozenChunkedBytesBuffer {
145 type Data = Bytes;
146 type Error = Infallible;
147
148 fn poll_frame(
149 mut self: Pin<&mut Self>, _: &mut Context<'_>,
150 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
151 Poll::Ready(self.chunks.pop_front().map(|chunk| Ok(Frame::data(chunk))))
152 }
153
154 fn size_hint(&self) -> SizeHint {
155 SizeHint::with_exact(self.len() as u64)
156 }
157}
158
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use http_body_util::BodyExt as _;
    use tokio::io::AsyncWriteExt as _;
    use tokio_test::{assert_ready, task::spawn as test_spawn};

    use super::*;

    const TEST_CHUNK_SIZE: usize = 16;
    const TEST_BUF_CHUNK_SIZED: &[u8] = b"hello world!!!!!";
    const TEST_BUF_LESS_THAN_CHUNK_SIZED: &[u8] = b"hello world!";
    const TEST_BUF_GREATER_THAN_CHUNK_SIZED: &[u8] = b"hello world, here i come!";

    #[test]
    fn single_write_fits_within_single_chunk() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_LESS_THAN_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_LESS_THAN_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn single_write_fits_single_chunk_exactly() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn single_write_strides_two_chunks() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_GREATER_THAN_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let n = result.unwrap();
        assert_eq!(n, TEST_BUF_GREATER_THAN_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 2);

        let total_capacity = chunked_buffer.chunks.len() * TEST_CHUNK_SIZE;
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - n);
    }

    #[test]
    fn two_writes_fit_two_chunks_exactly() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);

        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let first_n = result.unwrap();
        assert_eq!(first_n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 1);
        assert_eq!(chunked_buffer.remaining_capacity, 0);

        let mut fut = test_spawn(chunked_buffer.write(TEST_BUF_CHUNK_SIZED));
        let result = assert_ready!(fut.poll());

        let second_n = result.unwrap();
        assert_eq!(second_n, TEST_BUF_CHUNK_SIZED.len());
        assert_eq!(chunked_buffer.chunks.len(), 2);
        assert_eq!(chunked_buffer.remaining_capacity, 0);
    }

    #[tokio::test]
    async fn all_chunks_returned_as_body() {
        let test_bufs = &[
            TEST_BUF_LESS_THAN_CHUNK_SIZED,
            TEST_BUF_CHUNK_SIZED,
            TEST_BUF_GREATER_THAN_CHUNK_SIZED,
        ];
        let test_bufs_total_len = test_bufs.iter().map(|buf| buf.len()).sum::<usize>();

        // Round up: a partially filled final chunk still has to be allocated.
        let required_chunks = test_bufs_total_len / TEST_CHUNK_SIZE
            + if test_bufs_total_len % TEST_CHUNK_SIZE > 0 {
                1
            } else {
                0
            };

        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);
        let total_capacity = required_chunks * TEST_CHUNK_SIZE;

        // Write every test buffer, accumulating exactly what the collected body is expected to contain.
        let mut expected_aggregated_body = BytesMut::new();
        let mut total_written = 0;
        for test_buf in test_bufs {
            chunked_buffer.write_all(test_buf).await.unwrap();
            expected_aggregated_body.put(*test_buf);
            total_written += test_buf.len();
        }

        assert_eq!(test_bufs_total_len, total_written);
        assert_eq!(chunked_buffer.chunks.len(), required_chunks);
        assert_eq!(chunked_buffer.remaining_capacity, total_capacity - total_written);

        let read_chunked_buffer = chunked_buffer.freeze();

        let actual_aggregated_body = read_chunked_buffer.collect().await.expect("cannot fail").to_bytes();

        assert_eq!(expected_aggregated_body.freeze(), actual_aggregated_body);
    }

    #[test]
    fn len_and_is_empty_track_written_bytes() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);
        assert!(chunked_buffer.is_empty());
        assert_eq!(chunked_buffer.len(), 0);

        chunked_buffer.put_slice(TEST_BUF_GREATER_THAN_CHUNK_SIZED);
        assert!(!chunked_buffer.is_empty());
        assert_eq!(chunked_buffer.len(), TEST_BUF_GREATER_THAN_CHUNK_SIZED.len());

        let frozen = chunked_buffer.freeze();
        assert!(!frozen.is_empty());
        assert_eq!(frozen.len(), TEST_BUF_GREATER_THAN_CHUNK_SIZED.len());
    }

    #[test]
    fn frozen_buffer_advances_across_chunk_boundaries() {
        let mut chunked_buffer = ChunkedBytesBuffer::new(TEST_CHUNK_SIZE);
        chunked_buffer.put_slice(TEST_BUF_GREATER_THAN_CHUNK_SIZED);

        let mut frozen = chunked_buffer.freeze();
        assert_eq!(frozen.remaining(), TEST_BUF_GREATER_THAN_CHUNK_SIZED.len());
        assert_eq!(frozen.chunk(), &TEST_BUF_GREATER_THAN_CHUNK_SIZED[..TEST_CHUNK_SIZE]);

        // Step past the end of the first chunk and partway into the second.
        frozen.advance(TEST_CHUNK_SIZE + 2);
        assert_eq!(frozen.chunk(), &TEST_BUF_GREATER_THAN_CHUNK_SIZED[TEST_CHUNK_SIZE + 2..]);
        assert_eq!(
            frozen.remaining(),
            TEST_BUF_GREATER_THAN_CHUNK_SIZED.len() - TEST_CHUNK_SIZE - 2
        );

        frozen.advance(frozen.remaining());
        assert!(!frozen.has_remaining());
        assert!(frozen.chunk().is_empty());
    }
}