Skip to main content

saluki_io/net/
addr.rs

1use std::{
2    fmt,
3    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4},
4    path::{Path, PathBuf},
5};
6
7use axum::extract::connect_info::Connected;
8use serde::Deserialize;
9use url::Url;
10
11use super::Connection;
12
13/// A listen address.
14///
15/// Listen addresses are used to bind listeners to specific local addresses and ports, and multiple address families and
16/// protocols are supported. In textual form, listen addresses are represented as URLs, with the scheme indicating the
17/// protocol and the authority/path representing the address to listen on.
18///
19/// # Examples
20///
21/// - `tcp://127.0.0.1:6789` (listen on IPv4 loopback, TCP port 6789)
22/// - `udp://[::1]:53` (listen on IPv6 loopback, UDP port 53)
23/// - `unixgram:///tmp/app.socket` (listen on a Unix datagram socket at `/tmp/app.socket`)
24/// - `unix:///tmp/app.socket` (listen on a Unix stream socket at `/tmp/app.socket`)
25#[derive(Clone, Debug, Deserialize)]
26#[serde(try_from = "String")]
27pub enum ListenAddress {
28    /// A TCP listen address.
29    Tcp(SocketAddr),
30
31    /// A UDP listen address.
32    Udp(SocketAddr),
33
34    /// A Unix datagram listen address.
35    Unixgram(PathBuf),
36
37    /// A Unix stream listen address.
38    Unix(PathBuf),
39}
40
41impl ListenAddress {
42    /// Creates a TCP address for the given port that listens on all interfaces.
43    pub const fn any_tcp(port: u16) -> Self {
44        Self::Tcp(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)))
45    }
46
47    /// Returns the socket type of the listen address.
48    pub const fn listener_type(&self) -> &'static str {
49        match self {
50            Self::Tcp(_) => "tcp",
51            Self::Udp(_) => "udp",
52            Self::Unixgram(_) => "unixgram",
53            Self::Unix(_) => "unix",
54        }
55    }
56
57    /// Returns a socket address that can be used to connect to the configured listen address with a bias for local
58    /// clients.
59    ///
60    /// When the listen address is a TCP or UDP address, this method returns a socket address that can be used to
61    /// connect to the listener bound to this listen address, such that if the listen address is unspecified
62    /// (`0.0.0.0`), the client will connect locally using `localhost`. When the listen address isn't unspecified or
63    /// already uses `localhost`, this method returns the listen address as-is.
64    ///
65    /// If the address is a Unix domain socket, this method returns `None`.
66    pub fn as_local_connect_addr(&self) -> Option<SocketAddr> {
67        match self {
68            Self::Tcp(addr) | Self::Udp(addr) => {
69                let mut connect_addr = *addr;
70                if connect_addr.ip().is_unspecified() {
71                    let localhost_ip = match connect_addr.is_ipv4() {
72                        true => IpAddr::V4(Ipv4Addr::LOCALHOST),
73                        false => IpAddr::V6(Ipv6Addr::LOCALHOST),
74                    };
75
76                    connect_addr.set_ip(localhost_ip);
77                }
78
79                Some(connect_addr)
80            }
81            // TODO: why did i do this? it's totally possible to connect to a unix domain socket locally...
82            // in fact, it's kind of the only way to connect to a unix domain socket :thonk:
83            Self::Unixgram(_) => None,
84            Self::Unix(_) => None,
85        }
86    }
87
88    /// Returns the Unix domain socket path if the address is a Unix domain socket in SOCK_STREAM mode.
89    ///
90    /// Returns `None` otherwise.
91    pub fn as_unix_stream_path(&self) -> Option<&Path> {
92        match self {
93            Self::Unix(path) => Some(path),
94            _ => None,
95        }
96    }
97}
98
99impl fmt::Display for ListenAddress {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        match self {
102            Self::Tcp(addr) => write!(f, "tcp://{}", addr),
103            Self::Udp(addr) => write!(f, "udp://{}", addr),
104            Self::Unixgram(path) => write!(f, "unixgram://{}", path.display()),
105            Self::Unix(path) => write!(f, "unix://{}", path.display()),
106        }
107    }
108}
109
110impl TryFrom<String> for ListenAddress {
111    type Error = String;
112
113    fn try_from(value: String) -> Result<Self, Self::Error> {
114        Self::try_from(value.as_str())
115    }
116}
117
118impl<'a> TryFrom<&'a str> for ListenAddress {
119    type Error = String;
120
121    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
122        let url = match Url::parse(value) {
123            Ok(url) => url,
124            Err(e) => match e {
125                url::ParseError::RelativeUrlWithoutBase => {
126                    Url::parse(&format!("unixgram://{}", value)).map_err(|e| e.to_string())?
127                }
128                _ => return Err(e.to_string()),
129            },
130        };
131
132        match url.scheme() {
133            "tcp" => {
134                let mut socket_addresses = url.socket_addrs(|| None).map_err(|e| e.to_string())?;
135                if socket_addresses.is_empty() {
136                    Err("listen address must resolve to at least one valid IP address/port pair".to_string())
137                } else {
138                    Ok(Self::Tcp(socket_addresses.swap_remove(0)))
139                }
140            }
141            "udp" => {
142                let mut socket_addresses = url.socket_addrs(|| None).map_err(|e| e.to_string())?;
143                if socket_addresses.is_empty() {
144                    Err("listen address must resolve to at least one valid IP address/port pair".to_string())
145                } else {
146                    Ok(Self::Udp(socket_addresses.swap_remove(0)))
147                }
148            }
149            "unixgram" => {
150                let path = url.path();
151                if path.is_empty() {
152                    return Err("socket path cannot be empty".to_string());
153                }
154
155                let path_buf = PathBuf::from(path);
156                if !path_buf.is_absolute() {
157                    return Err("socket path must be absolute".to_string());
158                }
159
160                Ok(Self::Unixgram(path_buf))
161            }
162            "unix" => {
163                let path = url.path();
164                if path.is_empty() {
165                    return Err("socket path cannot be empty".to_string());
166                }
167
168                let path_buf = PathBuf::from(path);
169                if !path_buf.is_absolute() {
170                    return Err("socket path must be absolute".to_string());
171                }
172
173                Ok(Self::Unix(path_buf))
174            }
175            scheme => Err(format!("unknown/unsupported address scheme '{}'", scheme)),
176        }
177    }
178}
179
/// Process credentials for a Unix domain socket connection.
///
/// When dealing with Unix domain sockets, they can be configured such that the "process credentials" of the remote peer
/// are sent as part of each received message. These "credentials" are the process ID of the remote peer, and the user
/// ID and group ID that the process is running as.
///
/// In some cases, this information can be useful for identifying the remote peer and enriching the received data in an
/// automatic way.
///
/// Two values compare equal when their process ID, user ID, and group ID all match.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProcessCredentials {
    /// Process ID of the remote peer.
    pub pid: i32,

    /// User ID of the remote peer process.
    pub uid: u32,

    /// Group ID of the remote peer process.
    pub gid: u32,
}
199
/// Reason UDS process credential detection failed.
///
/// Implements [`std::error::Error`], so it can be boxed or propagated like any other error value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProcessCredentialsError {
    /// Ancillary data was present but didn't contain usable process credentials.
    InvalidCredentials,

    /// Process credentials were present, but the PID was zero.
    ZeroPid,

    /// UDS process credential detection isn't supported on this platform.
    UnsupportedPlatform,
}

impl ProcessCredentialsError {
    /// Returns a concise identifier for the failure reason.
    ///
    /// The identifier is a fixed, lowercase, hyphen-separated string, distinct from the sentence produced by the
    /// `Display` implementation.
    pub const fn identifier(&self) -> &'static str {
        match self {
            Self::InvalidCredentials => "invalid-credentials",
            Self::ZeroPid => "zero-pid",
            Self::UnsupportedPlatform => "unsupported-platform",
        }
    }
}

impl fmt::Display for ProcessCredentialsError {
    /// Writes a short human-readable description of the failure reason.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::InvalidCredentials => "invalid process credentials",
            Self::ZeroPid => "process credential PID is zero",
            Self::UnsupportedPlatform => "process credentials are unsupported on this platform",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ProcessCredentialsError {}
233
234/// Process identity associated with a Unix domain socket peer.
235#[derive(Clone)]
236pub enum ProcessIdentity {
237    /// Process credentials were detected.
238    Credentials(ProcessCredentials),
239
240    /// Process credential detection failed.
241    Error(ProcessCredentialsError),
242
243    /// Process identity isn't available for this peer.
244    Unavailable,
245}
246
247impl ProcessIdentity {
248    /// Returns process credentials, if they were detected.
249    pub fn credentials(&self) -> Option<&ProcessCredentials> {
250        match self {
251            Self::Credentials(creds) => Some(creds),
252            Self::Error(_) | Self::Unavailable => None,
253        }
254    }
255
256    /// Returns `true` if process credential detection failed.
257    pub const fn is_error(&self) -> bool {
258        matches!(self, Self::Error(_))
259    }
260}
261
262/// Connection address.
263///
264/// A generic representation of the address of a remote peer. This can either be a typical socket address (used for
265/// IPv4/IPv6), or potentially the process credentials of a Unix domain socket connection.
266#[derive(Clone)]
267pub enum ConnectionAddress {
268    /// A socket-like address.
269    SocketLike(SocketAddr),
270
271    /// A process-like address.
272    ProcessLike(ProcessIdentity),
273}
274
275impl fmt::Display for ConnectionAddress {
276    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277        match self {
278            Self::SocketLike(addr) => write!(f, "{}", addr),
279            Self::ProcessLike(identity) => match identity {
280                ProcessIdentity::Credentials(creds) => {
281                    write!(f, "<pid={} uid={} gid={}>", creds.pid, creds.uid, creds.gid)
282                }
283                ProcessIdentity::Error(error) => write!(f, "<origin-detection-error: {}>", error.identifier()),
284                ProcessIdentity::Unavailable => write!(f, "<no-origin>"),
285            },
286        }
287    }
288}
289
290impl ConnectionAddress {
291    /// Returns process credentials for a Unix domain socket peer, if available.
292    pub fn process_credentials(&self) -> Option<&ProcessCredentials> {
293        match self {
294            Self::ProcessLike(identity) => identity.credentials(),
295            Self::SocketLike(_) => None,
296        }
297    }
298
299    /// Returns `true` if Unix domain socket process credential detection failed.
300    pub const fn has_process_credential_error(&self) -> bool {
301        match self {
302            Self::ProcessLike(identity) => identity.is_error(),
303            Self::SocketLike(_) => false,
304        }
305    }
306}
307
308impl From<SocketAddr> for ConnectionAddress {
309    fn from(value: SocketAddr) -> Self {
310        Self::SocketLike(value)
311    }
312}
313
314impl From<ProcessCredentials> for ConnectionAddress {
315    fn from(creds: ProcessCredentials) -> Self {
316        Self::ProcessLike(ProcessIdentity::Credentials(creds))
317    }
318}
319
320impl<'a> Connected<&'a Connection> for ConnectionAddress {
321    fn connect_info(target: &'a Connection) -> Self {
322        target.remote_addr()
323    }
324}
325
/// A gRPC target address.
///
/// This represents the address of a gRPC server that can be connected to. `GrpcTargetAddress` exposes a `Display`
/// implementation that emits the target address following the rules of the [gRPC Name
/// Resolution][grpc_name_resolution_docs] documentation.
///
/// Only connection-oriented transports are supported: TCP and Unix domain sockets in SOCK_STREAM mode.
///
/// [grpc_name_resolution_docs]: https://github.com/grpc/grpc/blob/master/doc/naming.md
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GrpcTargetAddress {
    /// A gRPC server reachable over TCP at the given socket address.
    Tcp(SocketAddr),

    /// A gRPC server reachable over a Unix stream socket at the given path.
    Unix(PathBuf),
}
339
340impl GrpcTargetAddress {
341    /// Creates a new `GrpcTargetAddress` from the given `ListenAddress`.
342    ///
343    /// For TCP addresses, this method converts unspecified addresses (`0.0.0.0` or `::`) to localhost
344    /// (`127.0.0.1` or `::1`) to ensure the advertised address matches TLS certificates.
345    ///
346    /// Returns `None` if the listen address isn't a connection-oriented transport.
347    pub fn try_from_listen_addr(listen_address: &ListenAddress) -> Option<Self> {
348        match listen_address {
349            ListenAddress::Tcp(_) => {
350                // For TCP, convert 0.0.0.0 to 127.0.0.1 to match TLS certificate
351                listen_address.as_local_connect_addr().map(GrpcTargetAddress::Tcp)
352            }
353            ListenAddress::Unix(path) => Some(GrpcTargetAddress::Unix(path.clone())),
354            _ => None,
355        }
356    }
357}
358
359impl fmt::Display for GrpcTargetAddress {
360    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
361        match self {
362            GrpcTargetAddress::Tcp(addr) => write!(f, "{}", addr),
363            GrpcTargetAddress::Unix(path) => write!(f, "unix://{}", path.display()),
364        }
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that unspecified IPv4 addresses are rewritten to loopback, while loopback and other specific
    /// addresses are returned unchanged, for both TCP and UDP listen addresses.
    #[test]
    fn test_as_local_connect_addr() {
        // (listen address, expected connect IP, expected connect port)
        let cases = [
            ("tcp://0.0.0.0:1234", Ipv4Addr::LOCALHOST, 1234),
            ("tcp://127.0.0.1:2345", Ipv4Addr::LOCALHOST, 2345),
            ("tcp://192.168.10.2:3456", Ipv4Addr::new(192, 168, 10, 2), 3456),
            ("udp://0.0.0.0:4567", Ipv4Addr::LOCALHOST, 4567),
            ("udp://127.0.0.1:5678", Ipv4Addr::LOCALHOST, 5678),
            ("udp://192.168.10.2:6789", Ipv4Addr::new(192, 168, 10, 2), 6789),
        ];

        for &(input, expected_ip, expected_port) in &cases {
            let listen_addr = ListenAddress::try_from(input).unwrap();
            let expected = SocketAddr::V4(SocketAddrV4::new(expected_ip, expected_port));

            assert_eq!(
                listen_addr.as_local_connect_addr(),
                Some(expected),
                "unexpected local connect address for `{}`",
                input
            );
        }
    }
}