1use std::{
2 io,
3 net::SocketAddr,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use bytes::BufMut;
9use pin_project::pin_project;
10use tokio::{
11 io::{AsyncRead, AsyncReadExt as _, AsyncWrite, ReadBuf},
12 net::{TcpStream, UdpSocket},
13};
14
15use super::{
16 addr::ConnectionAddress,
17 unix::{unix_recvmsg, unixgram_recvmsg},
18};
19
20#[pin_project(project = ConnectionProjected)]
25pub enum Connection {
26 Tcp(#[pin] TcpStream, SocketAddr),
28
29 #[cfg(unix)]
31 Unix(#[pin] tokio::net::UnixStream),
32}
33
34impl Connection {
35 async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
36 match self {
37 Self::Tcp(inner, addr) => inner.read_buf(buf).await.map(|n| (n, (*addr).into())),
38 #[cfg(unix)]
39 Self::Unix(inner) => unix_recvmsg(inner, buf).await,
40 }
41 }
42
43 pub(super) fn remote_addr(&self) -> ConnectionAddress {
44 match self {
45 Self::Tcp(_, addr) => ConnectionAddress::SocketLike(*addr),
46 #[cfg(unix)]
47 Self::Unix(_) => ConnectionAddress::ProcessLike(None),
48 }
49 }
50}
51
52impl AsyncRead for Connection {
53 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
54 match self.project() {
55 ConnectionProjected::Tcp(inner, _) => inner.poll_read(cx, buf),
56 #[cfg(unix)]
57 ConnectionProjected::Unix(inner) => inner.poll_read(cx, buf),
58 }
59 }
60}
61
62impl AsyncWrite for Connection {
63 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
64 match self.project() {
65 ConnectionProjected::Tcp(inner, _) => inner.poll_write(cx, buf),
66 #[cfg(unix)]
67 ConnectionProjected::Unix(inner) => inner.poll_write(cx, buf),
68 }
69 }
70
71 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72 match self.project() {
73 ConnectionProjected::Tcp(inner, _) => inner.poll_flush(cx),
74 #[cfg(unix)]
75 ConnectionProjected::Unix(inner) => inner.poll_flush(cx),
76 }
77 }
78
79 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
80 match self.project() {
81 ConnectionProjected::Tcp(inner, _) => inner.poll_shutdown(cx),
82 #[cfg(unix)]
83 ConnectionProjected::Unix(inner) => inner.poll_shutdown(cx),
84 }
85 }
86}
87
88enum Connectionless {
93 Udp(UdpSocket),
95
96 #[cfg(unix)]
98 Unixgram(tokio::net::UnixDatagram),
99}
100
101impl Connectionless {
102 async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
103 match self {
104 Self::Udp(inner) => inner.recv_buf_from(buf).await.map(|(n, addr)| (n, addr.into())),
105 #[cfg(unix)]
106 Self::Unixgram(inner) => unixgram_recvmsg(inner, buf).await,
107 }
108 }
109}
110
111enum StreamInner {
112 Connection { socket: Connection },
113 Connectionless { socket: Connectionless },
114}
115
116pub struct Stream {
136 inner: StreamInner,
137}
138
139impl Stream {
140 pub fn is_connectionless(&self) -> bool {
142 matches!(self.inner, StreamInner::Connectionless { .. })
143 }
144
145 pub async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
153 match &mut self.inner {
154 StreamInner::Connection { socket } => socket.receive(buf).await,
155 StreamInner::Connectionless { socket } => socket.receive(buf).await,
156 }
157 }
158}
159
160impl From<(TcpStream, SocketAddr)> for Stream {
161 fn from((stream, remote_addr): (TcpStream, SocketAddr)) -> Self {
162 Self {
163 inner: StreamInner::Connection {
164 socket: Connection::Tcp(stream, remote_addr),
165 },
166 }
167 }
168}
169
170impl From<UdpSocket> for Stream {
171 fn from(socket: UdpSocket) -> Self {
172 Self {
173 inner: StreamInner::Connectionless {
174 socket: Connectionless::Udp(socket),
175 },
176 }
177 }
178}
179
180#[cfg(unix)]
181impl From<tokio::net::UnixDatagram> for Stream {
182 fn from(socket: tokio::net::UnixDatagram) -> Self {
183 Self {
184 inner: StreamInner::Connectionless {
185 socket: Connectionless::Unixgram(socket),
186 },
187 }
188 }
189}
190
191#[cfg(unix)]
192impl From<tokio::net::UnixStream> for Stream {
193 fn from(stream: tokio::net::UnixStream) -> Self {
194 Self {
195 inner: StreamInner::Connection {
196 socket: Connection::Unix(stream),
197 },
198 }
199 }
200}