Skip to main content

saluki_io/net/
stream.rs

1use std::{
2    io,
3    net::SocketAddr,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use bytes::BufMut;
9use pin_project::pin_project;
10use tokio::{
11    io::{AsyncRead, AsyncReadExt as _, AsyncWrite, ReadBuf},
12    net::{TcpStream, UdpSocket},
13};
14
15use super::addr::ConnectionAddress;
16#[cfg(unix)]
17use super::{
18    addr::ProcessIdentity,
19    unix::{unix_recvmsg, unixgram_recvmsg},
20};
21
22/// A connection-oriented socket.
23///
24/// This type wraps network sockets that operate in a connection-oriented manner, such as TCP or Unix domain sockets in
25/// stream mode.
26#[pin_project(project = ConnectionProjected)]
27pub enum Connection {
28    /// A TCP socket.
29    Tcp(#[pin] TcpStream, SocketAddr),
30
31    /// A Unix domain socket in stream mode (SOCK_STREAM).
32    #[cfg(unix)]
33    Unix(#[pin] tokio::net::UnixStream),
34}
35
36impl Connection {
37    async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
38        match self {
39            Self::Tcp(inner, addr) => inner.read_buf(buf).await.map(|n| (n, (*addr).into())),
40            #[cfg(unix)]
41            Self::Unix(inner) => unix_recvmsg(inner, buf).await,
42        }
43    }
44
45    pub(super) fn remote_addr(&self) -> ConnectionAddress {
46        match self {
47            Self::Tcp(_, addr) => ConnectionAddress::SocketLike(*addr),
48            #[cfg(unix)]
49            Self::Unix(_) => ConnectionAddress::ProcessLike(ProcessIdentity::Unavailable),
50        }
51    }
52}
53
54impl AsyncRead for Connection {
55    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
56        match self.project() {
57            ConnectionProjected::Tcp(inner, _) => inner.poll_read(cx, buf),
58            #[cfg(unix)]
59            ConnectionProjected::Unix(inner) => inner.poll_read(cx, buf),
60        }
61    }
62}
63
64impl AsyncWrite for Connection {
65    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
66        match self.project() {
67            ConnectionProjected::Tcp(inner, _) => inner.poll_write(cx, buf),
68            #[cfg(unix)]
69            ConnectionProjected::Unix(inner) => inner.poll_write(cx, buf),
70        }
71    }
72
73    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
74        match self.project() {
75            ConnectionProjected::Tcp(inner, _) => inner.poll_flush(cx),
76            #[cfg(unix)]
77            ConnectionProjected::Unix(inner) => inner.poll_flush(cx),
78        }
79    }
80
81    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
82        match self.project() {
83            ConnectionProjected::Tcp(inner, _) => inner.poll_shutdown(cx),
84            #[cfg(unix)]
85            ConnectionProjected::Unix(inner) => inner.poll_shutdown(cx),
86        }
87    }
88}
89
90/// A connectionless socket.
91///
92/// This type wraps network sockets that operate in a connectionless manner, such as UDP or Unix domain sockets in
93/// datagram mode.
94enum Connectionless {
95    /// A UDP socket.
96    Udp(UdpSocket),
97
98    /// A Unix domain socket in datagram mode (SOCK_DGRAM).
99    #[cfg(unix)]
100    Unixgram(tokio::net::UnixDatagram),
101}
102
103impl Connectionless {
104    async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
105        match self {
106            Self::Udp(inner) => inner.recv_buf_from(buf).await.map(|(n, addr)| (n, addr.into())),
107            #[cfg(unix)]
108            Self::Unixgram(inner) => unixgram_recvmsg(inner, buf).await,
109        }
110    }
111}
112
113enum StreamInner {
114    Connection { socket: Connection },
115    Connectionless { socket: Connectionless },
116}
117
118/// A network stream.
119///
120/// `Stream` provides an abstraction over connectionless and connection-oriented network sockets. In many cases, it's
121/// not required to know the exact socket family (for example, TCP, UDP, Unix domain socket) that's being used, and it can be
122/// beneficial to allow abstracting over the differences to facilitate simpler code.
123///
124/// ## Connection-oriented mode
125///
126/// In connection-oriented mode, the stream is backed by a socket that operates in a connection-oriented manner, which
127/// ensures a reliable, ordered stream of messages to and from the remote peer.
128///
129/// The connection address returned when receiving data _should_ be stable for the life of the `Stream`.
130///
131/// ## Connectionless mode
132///
133/// In connectionless mode, the stream is backed by a socket that operates in a connectionless manner, which doesn't
134/// provide any assurances around reliability and ordering of messages to and from the remote peer. While a stream might
135/// be backed by a Unix domain socket in datagram mode, which _does_ provide reliability of messages, this can't and
136/// shouldn't be relied upon when using `Stream`.
137pub struct Stream {
138    inner: StreamInner,
139}
140
141impl Stream {
142    /// Returns `true` if the stream is connectionless.
143    pub fn is_connectionless(&self) -> bool {
144        matches!(self.inner, StreamInner::Connectionless { .. })
145    }
146
147    /// Receives data from the stream.
148    ///
149    /// On success, returns the number of bytes read and the address from whence the data came.
150    ///
151    /// ## Errors
152    ///
153    /// If the underlying system call fails, an error is returned.
154    pub async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
155        match &mut self.inner {
156            StreamInner::Connection { socket } => socket.receive(buf).await,
157            StreamInner::Connectionless { socket } => socket.receive(buf).await,
158        }
159    }
160
161    #[cfg(test)]
162    pub(crate) fn recv_buffer_size(&self) -> io::Result<usize> {
163        match &self.inner {
164            StreamInner::Connection { socket } => match socket {
165                Connection::Tcp(inner, _) => socket2::SockRef::from(inner).recv_buffer_size(),
166                #[cfg(unix)]
167                Connection::Unix(inner) => socket2::SockRef::from(inner).recv_buffer_size(),
168            },
169            StreamInner::Connectionless { socket } => match socket {
170                Connectionless::Udp(inner) => socket2::SockRef::from(inner).recv_buffer_size(),
171                #[cfg(unix)]
172                Connectionless::Unixgram(inner) => socket2::SockRef::from(inner).recv_buffer_size(),
173            },
174        }
175    }
176}
177
178impl From<(TcpStream, SocketAddr)> for Stream {
179    fn from((stream, remote_addr): (TcpStream, SocketAddr)) -> Self {
180        Self {
181            inner: StreamInner::Connection {
182                socket: Connection::Tcp(stream, remote_addr),
183            },
184        }
185    }
186}
187
188impl From<UdpSocket> for Stream {
189    fn from(socket: UdpSocket) -> Self {
190        Self {
191            inner: StreamInner::Connectionless {
192                socket: Connectionless::Udp(socket),
193            },
194        }
195    }
196}
197
198#[cfg(unix)]
199impl From<tokio::net::UnixDatagram> for Stream {
200    fn from(socket: tokio::net::UnixDatagram) -> Self {
201        Self {
202            inner: StreamInner::Connectionless {
203                socket: Connectionless::Unixgram(socket),
204            },
205        }
206    }
207}
208
209#[cfg(unix)]
210impl From<tokio::net::UnixStream> for Stream {
211    fn from(stream: tokio::net::UnixStream) -> Self {
212        Self {
213            inner: StreamInner::Connection {
214                socket: Connection::Unix(stream),
215            },
216        }
217    }
218}