1use std::{
2 io,
3 net::SocketAddr,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use bytes::BufMut;
9use pin_project::pin_project;
10use tokio::{
11 io::{AsyncRead, AsyncReadExt as _, AsyncWrite, ReadBuf},
12 net::{TcpStream, UdpSocket},
13};
14
15use super::addr::ConnectionAddress;
16#[cfg(unix)]
17use super::{
18 addr::ProcessIdentity,
19 unix::{unix_recvmsg, unixgram_recvmsg},
20};
21
22#[pin_project(project = ConnectionProjected)]
27pub enum Connection {
28 Tcp(#[pin] TcpStream, SocketAddr),
30
31 #[cfg(unix)]
33 Unix(#[pin] tokio::net::UnixStream),
34}
35
36impl Connection {
37 async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
38 match self {
39 Self::Tcp(inner, addr) => inner.read_buf(buf).await.map(|n| (n, (*addr).into())),
40 #[cfg(unix)]
41 Self::Unix(inner) => unix_recvmsg(inner, buf).await,
42 }
43 }
44
45 pub(super) fn remote_addr(&self) -> ConnectionAddress {
46 match self {
47 Self::Tcp(_, addr) => ConnectionAddress::SocketLike(*addr),
48 #[cfg(unix)]
49 Self::Unix(_) => ConnectionAddress::ProcessLike(ProcessIdentity::Unavailable),
50 }
51 }
52}
53
54impl AsyncRead for Connection {
55 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
56 match self.project() {
57 ConnectionProjected::Tcp(inner, _) => inner.poll_read(cx, buf),
58 #[cfg(unix)]
59 ConnectionProjected::Unix(inner) => inner.poll_read(cx, buf),
60 }
61 }
62}
63
64impl AsyncWrite for Connection {
65 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
66 match self.project() {
67 ConnectionProjected::Tcp(inner, _) => inner.poll_write(cx, buf),
68 #[cfg(unix)]
69 ConnectionProjected::Unix(inner) => inner.poll_write(cx, buf),
70 }
71 }
72
73 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
74 match self.project() {
75 ConnectionProjected::Tcp(inner, _) => inner.poll_flush(cx),
76 #[cfg(unix)]
77 ConnectionProjected::Unix(inner) => inner.poll_flush(cx),
78 }
79 }
80
81 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
82 match self.project() {
83 ConnectionProjected::Tcp(inner, _) => inner.poll_shutdown(cx),
84 #[cfg(unix)]
85 ConnectionProjected::Unix(inner) => inner.poll_shutdown(cx),
86 }
87 }
88}
89
90enum Connectionless {
95 Udp(UdpSocket),
97
98 #[cfg(unix)]
100 Unixgram(tokio::net::UnixDatagram),
101}
102
103impl Connectionless {
104 async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
105 match self {
106 Self::Udp(inner) => inner.recv_buf_from(buf).await.map(|(n, addr)| (n, addr.into())),
107 #[cfg(unix)]
108 Self::Unixgram(inner) => unixgram_recvmsg(inner, buf).await,
109 }
110 }
111}
112
113enum StreamInner {
114 Connection { socket: Connection },
115 Connectionless { socket: Connectionless },
116}
117
118pub struct Stream {
138 inner: StreamInner,
139}
140
141impl Stream {
142 pub fn is_connectionless(&self) -> bool {
144 matches!(self.inner, StreamInner::Connectionless { .. })
145 }
146
147 pub async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
155 match &mut self.inner {
156 StreamInner::Connection { socket } => socket.receive(buf).await,
157 StreamInner::Connectionless { socket } => socket.receive(buf).await,
158 }
159 }
160
161 #[cfg(test)]
162 pub(crate) fn recv_buffer_size(&self) -> io::Result<usize> {
163 match &self.inner {
164 StreamInner::Connection { socket } => match socket {
165 Connection::Tcp(inner, _) => socket2::SockRef::from(inner).recv_buffer_size(),
166 #[cfg(unix)]
167 Connection::Unix(inner) => socket2::SockRef::from(inner).recv_buffer_size(),
168 },
169 StreamInner::Connectionless { socket } => match socket {
170 Connectionless::Udp(inner) => socket2::SockRef::from(inner).recv_buffer_size(),
171 #[cfg(unix)]
172 Connectionless::Unixgram(inner) => socket2::SockRef::from(inner).recv_buffer_size(),
173 },
174 }
175 }
176}
177
178impl From<(TcpStream, SocketAddr)> for Stream {
179 fn from((stream, remote_addr): (TcpStream, SocketAddr)) -> Self {
180 Self {
181 inner: StreamInner::Connection {
182 socket: Connection::Tcp(stream, remote_addr),
183 },
184 }
185 }
186}
187
188impl From<UdpSocket> for Stream {
189 fn from(socket: UdpSocket) -> Self {
190 Self {
191 inner: StreamInner::Connectionless {
192 socket: Connectionless::Udp(socket),
193 },
194 }
195 }
196}
197
198#[cfg(unix)]
199impl From<tokio::net::UnixDatagram> for Stream {
200 fn from(socket: tokio::net::UnixDatagram) -> Self {
201 Self {
202 inner: StreamInner::Connectionless {
203 socket: Connectionless::Unixgram(socket),
204 },
205 }
206 }
207}
208
209#[cfg(unix)]
210impl From<tokio::net::UnixStream> for Stream {
211 fn from(stream: tokio::net::UnixStream) -> Self {
212 Self {
213 inner: StreamInner::Connection {
214 socket: Connection::Unix(stream),
215 },
216 }
217 }
218}