1use std::{
2 io,
3 net::SocketAddr,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use bytes::BufMut;
9use pin_project::pin_project;
10use tokio::{
11 io::{AsyncRead, AsyncReadExt as _, AsyncWrite, ReadBuf},
12 net::{TcpStream, UdpSocket},
13};
14
15use super::{
16 addr::ConnectionAddress,
17 unix::{unix_recvmsg, unixgram_recvmsg},
18};
19
/// A connection-oriented (byte-stream) socket: either a TCP stream or, on
/// Unix platforms, a Unix-domain stream socket.
///
/// The pin projection type `ConnectionProjected` is what the `AsyncRead` /
/// `AsyncWrite` impls match on to forward polling to the inner socket.
#[pin_project(project = ConnectionProjected)]
pub enum Connection {
    /// A TCP stream paired with the remote peer's address.
    Tcp(#[pin] TcpStream, SocketAddr),

    /// A Unix-domain stream socket. No peer address is stored for it;
    /// `remote_addr` reports `ConnectionAddress::ProcessLike(None)`.
    #[cfg(unix)]
    Unix(#[pin] tokio::net::UnixStream),
}
33
34impl Connection {
35 async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
36 match self {
37 Self::Tcp(inner, addr) => inner.read_buf(buf).await.map(|n| (n, (*addr).into())),
38 #[cfg(unix)]
39 Self::Unix(inner) => unix_recvmsg(inner, buf).await,
40 }
41 }
42
43 pub(super) fn remote_addr(&self) -> ConnectionAddress {
44 match self {
45 Self::Tcp(_, addr) => ConnectionAddress::SocketLike(*addr),
46 #[cfg(unix)]
47 Self::Unix(_) => ConnectionAddress::ProcessLike(None),
48 }
49 }
50}
51
52impl AsyncRead for Connection {
53 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
54 match self.project() {
55 ConnectionProjected::Tcp(inner, _) => inner.poll_read(cx, buf),
56 #[cfg(unix)]
57 ConnectionProjected::Unix(inner) => inner.poll_read(cx, buf),
58 }
59 }
60}
61
62impl AsyncWrite for Connection {
63 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
64 match self.project() {
65 ConnectionProjected::Tcp(inner, _) => inner.poll_write(cx, buf),
66 #[cfg(unix)]
67 ConnectionProjected::Unix(inner) => inner.poll_write(cx, buf),
68 }
69 }
70
71 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72 match self.project() {
73 ConnectionProjected::Tcp(inner, _) => inner.poll_flush(cx),
74 #[cfg(unix)]
75 ConnectionProjected::Unix(inner) => inner.poll_flush(cx),
76 }
77 }
78
79 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
80 match self.project() {
81 ConnectionProjected::Tcp(inner, _) => inner.poll_shutdown(cx),
82 #[cfg(unix)]
83 ConnectionProjected::Unix(inner) => inner.poll_shutdown(cx),
84 }
85 }
86}
87
/// A connectionless (datagram) socket: UDP or, on Unix platforms, a
/// Unix-domain datagram socket.
enum Connectionless {
    /// A UDP socket.
    Udp(UdpSocket),

    /// A Unix-domain datagram socket.
    #[cfg(unix)]
    Unixgram(tokio::net::UnixDatagram),
}
100
101impl Connectionless {
102 async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
103 match self {
104 Self::Udp(inner) => inner.recv_buf_from(buf).await.map(|(n, addr)| (n, addr.into())),
105 #[cfg(unix)]
106 Self::Unixgram(inner) => unixgram_recvmsg(inner, buf).await,
107 }
108 }
109}
110
/// The transport behind a `Stream`, split by whether the socket is
/// connection-oriented or connectionless.
enum StreamInner {
    /// A stream socket (TCP, or a Unix stream on Unix platforms).
    Connection { socket: Connection },
    /// A datagram socket (UDP, or a Unix datagram on Unix platforms).
    Connectionless { socket: Connectionless },
}
115
/// A socket handle that hides whether the underlying transport is
/// connection-oriented (TCP / Unix stream) or connectionless
/// (UDP / Unix datagram).
///
/// Build one with the `From` impls for `(TcpStream, SocketAddr)` and
/// `UdpSocket`, or, on Unix, for `UnixStream` and `UnixDatagram`.
pub struct Stream {
    // Which transport this stream wraps; see `StreamInner`.
    inner: StreamInner,
}
138
139impl Stream {
140 pub fn is_connectionless(&self) -> bool {
142 matches!(self.inner, StreamInner::Connectionless { .. })
143 }
144
145 pub async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
153 match &mut self.inner {
154 StreamInner::Connection { socket } => socket.receive(buf).await,
155 StreamInner::Connectionless { socket } => socket.receive(buf).await,
156 }
157 }
158
159 #[cfg(test)]
160 pub(crate) fn recv_buffer_size(&self) -> io::Result<usize> {
161 match &self.inner {
162 StreamInner::Connection { socket } => match socket {
163 Connection::Tcp(inner, _) => socket2::SockRef::from(inner).recv_buffer_size(),
164 #[cfg(unix)]
165 Connection::Unix(inner) => socket2::SockRef::from(inner).recv_buffer_size(),
166 },
167 StreamInner::Connectionless { socket } => match socket {
168 Connectionless::Udp(inner) => socket2::SockRef::from(inner).recv_buffer_size(),
169 #[cfg(unix)]
170 Connectionless::Unixgram(inner) => socket2::SockRef::from(inner).recv_buffer_size(),
171 },
172 }
173 }
174}
175
176impl From<(TcpStream, SocketAddr)> for Stream {
177 fn from((stream, remote_addr): (TcpStream, SocketAddr)) -> Self {
178 Self {
179 inner: StreamInner::Connection {
180 socket: Connection::Tcp(stream, remote_addr),
181 },
182 }
183 }
184}
185
186impl From<UdpSocket> for Stream {
187 fn from(socket: UdpSocket) -> Self {
188 Self {
189 inner: StreamInner::Connectionless {
190 socket: Connectionless::Udp(socket),
191 },
192 }
193 }
194}
195
196#[cfg(unix)]
197impl From<tokio::net::UnixDatagram> for Stream {
198 fn from(socket: tokio::net::UnixDatagram) -> Self {
199 Self {
200 inner: StreamInner::Connectionless {
201 socket: Connectionless::Unixgram(socket),
202 },
203 }
204 }
205}
206
207#[cfg(unix)]
208impl From<tokio::net::UnixStream> for Stream {
209 fn from(stream: tokio::net::UnixStream) -> Self {
210 Self {
211 inner: StreamInner::Connection {
212 socket: Connection::Unix(stream),
213 },
214 }
215 }
216}