saluki_io/net/stream.rs

1use std::{
2    io,
3    net::SocketAddr,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use bytes::BufMut;
9use pin_project::pin_project;
10use tokio::{
11    io::{AsyncRead, AsyncReadExt as _, AsyncWrite, ReadBuf},
12    net::{TcpStream, UdpSocket},
13};
14
15use super::{
16    addr::ConnectionAddress,
17    unix::{unix_recvmsg, unixgram_recvmsg},
18};
19
20/// A connection-oriented socket.
21///
22/// This type wraps network sockets that operate in a connection-oriented manner, such as TCP or Unix domain sockets in
23/// stream mode.
24#[pin_project(project = ConnectionProjected)]
25pub enum Connection {
26    /// A TCP socket.
27    Tcp(#[pin] TcpStream, SocketAddr),
28
29    /// A Unix domain socket in stream mode (SOCK_STREAM).
30    #[cfg(unix)]
31    Unix(#[pin] tokio::net::UnixStream),
32}
33
34impl Connection {
35    async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
36        match self {
37            Self::Tcp(inner, addr) => inner.read_buf(buf).await.map(|n| (n, (*addr).into())),
38            #[cfg(unix)]
39            Self::Unix(inner) => unix_recvmsg(inner, buf).await,
40        }
41    }
42
43    pub(super) fn remote_addr(&self) -> ConnectionAddress {
44        match self {
45            Self::Tcp(_, addr) => ConnectionAddress::SocketLike(*addr),
46            #[cfg(unix)]
47            Self::Unix(_) => ConnectionAddress::ProcessLike(None),
48        }
49    }
50}
51
52impl AsyncRead for Connection {
53    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
54        match self.project() {
55            ConnectionProjected::Tcp(inner, _) => inner.poll_read(cx, buf),
56            #[cfg(unix)]
57            ConnectionProjected::Unix(inner) => inner.poll_read(cx, buf),
58        }
59    }
60}
61
62impl AsyncWrite for Connection {
63    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
64        match self.project() {
65            ConnectionProjected::Tcp(inner, _) => inner.poll_write(cx, buf),
66            #[cfg(unix)]
67            ConnectionProjected::Unix(inner) => inner.poll_write(cx, buf),
68        }
69    }
70
71    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72        match self.project() {
73            ConnectionProjected::Tcp(inner, _) => inner.poll_flush(cx),
74            #[cfg(unix)]
75            ConnectionProjected::Unix(inner) => inner.poll_flush(cx),
76        }
77    }
78
79    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
80        match self.project() {
81            ConnectionProjected::Tcp(inner, _) => inner.poll_shutdown(cx),
82            #[cfg(unix)]
83            ConnectionProjected::Unix(inner) => inner.poll_shutdown(cx),
84        }
85    }
86}
87
88/// A connectionless socket.
89///
90/// This type wraps network sockets that operate in a connectionless manner, such as UDP or Unix domain sockets in
91/// datagram mode.
92enum Connectionless {
93    /// A UDP socket.
94    Udp(UdpSocket),
95
96    /// A Unix domain socket in datagram mode (SOCK_DGRAM).
97    #[cfg(unix)]
98    Unixgram(tokio::net::UnixDatagram),
99}
100
101impl Connectionless {
102    async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
103        match self {
104            Self::Udp(inner) => inner.recv_buf_from(buf).await.map(|(n, addr)| (n, addr.into())),
105            #[cfg(unix)]
106            Self::Unixgram(inner) => unixgram_recvmsg(inner, buf).await,
107        }
108    }
109}
110
111enum StreamInner {
112    Connection { socket: Connection },
113    Connectionless { socket: Connectionless },
114}
115
116/// A network stream.
117///
118/// `Stream` provides an abstraction over connectionless and connection-oriented network sockets. In many cases, it is
119/// not required to know the exact socket family (e.g. TCP, UDP, Unix domain socket) that is being used, and it can be
120/// beneficial to allow abstracting over the differences to facilitate simpler code.
121///
122/// ## Connection-oriented mode
123///
124/// In connection-oriented mode, the stream is backed by a socket that operates in a connection-oriented manner, which
125/// ensures a reliable, ordered stream of messages to and from the remote peer.
126///
127/// The connection address returned when receiving data _should_ be stable for the life of the `Stream`.
128///
129/// ## Connectionless mode
130///
131/// In connectionless mode, the stream is backed by a socket that operates in a connectionless manner, which does not
132/// provide any assurances around reliability and ordering of messages to and from the remote peer. While a stream might
133/// be backed by a Unix domain socket in datagram mode, which _does_ provide reliability of messages, this cannot and
134/// should not be relied upon when using `Stream`.
135pub struct Stream {
136    inner: StreamInner,
137}
138
139impl Stream {
140    /// Returns `true` if the stream is connectionless.
141    pub fn is_connectionless(&self) -> bool {
142        matches!(self.inner, StreamInner::Connectionless { .. })
143    }
144
145    /// Receives data from the stream.
146    ///
147    /// On success, returns the number of bytes read and the address from whence the data came.
148    ///
149    /// ## Errors
150    ///
151    /// If the underlying system call fails, an error is returned.
152    pub async fn receive<B: BufMut>(&mut self, buf: &mut B) -> io::Result<(usize, ConnectionAddress)> {
153        match &mut self.inner {
154            StreamInner::Connection { socket } => socket.receive(buf).await,
155            StreamInner::Connectionless { socket } => socket.receive(buf).await,
156        }
157    }
158}
159
160impl From<(TcpStream, SocketAddr)> for Stream {
161    fn from((stream, remote_addr): (TcpStream, SocketAddr)) -> Self {
162        Self {
163            inner: StreamInner::Connection {
164                socket: Connection::Tcp(stream, remote_addr),
165            },
166        }
167    }
168}
169
170impl From<UdpSocket> for Stream {
171    fn from(socket: UdpSocket) -> Self {
172        Self {
173            inner: StreamInner::Connectionless {
174                socket: Connectionless::Udp(socket),
175            },
176        }
177    }
178}
179
180#[cfg(unix)]
181impl From<tokio::net::UnixDatagram> for Stream {
182    fn from(socket: tokio::net::UnixDatagram) -> Self {
183        Self {
184            inner: StreamInner::Connectionless {
185                socket: Connectionless::Unixgram(socket),
186            },
187        }
188    }
189}
190
191#[cfg(unix)]
192impl From<tokio::net::UnixStream> for Stream {
193    fn from(stream: tokio::net::UnixStream) -> Self {
194        Self {
195            inner: StreamInner::Connection {
196                socket: Connection::Unix(stream),
197            },
198        }
199    }
200}