Skip to main content

saluki_io/net/unix/
linux.rs

1use std::{io, mem, os::fd::AsRawFd};
2
3use bytes::BufMut;
4use socket2::{Domain, MaybeUninitSlice, MsgHdrMut, Protocol, SockAddr, SockAddrStorage, SockRef, Socket, Type};
5use tokio::net::UnixDatagram;
6
7use super::ancillary::{ControlMessage, SocketCredentialsAncillaryData};
8use crate::net::addr::{ConnectionAddress, ProcessCredentials, ProcessCredentialsError, ProcessIdentity};
9
10/// Enables the `SO_PASSCRED` option on the given socket.
11///
12/// ## Errors
13///
14/// If the underlying system call fails, an error is returned.
15pub fn enable_uds_socket_credentials<'sock, S>(socket: &'sock S) -> io::Result<()>
16where
17    SockRef<'sock>: From<&'sock S>,
18{
19    let sock_ref = SockRef::from(socket);
20    sock_ref.set_passcred(true)
21}
22
/// Receives one message from a Unix domain socket into `buf`, along with any process credentials sent as ancillary
/// data.
///
/// The payload is written into the spare capacity exposed by `buf.chunk_mut()`. `buf` is then advanced by the number
/// of bytes received, which is also returned.
///
/// The returned address is always `ConnectionAddress::ProcessLike`, carrying one of the following:
///
/// - `ProcessIdentity::Unavailable` when no control data was received. This is presumably because `SO_PASSCRED` is
///   not enabled on the socket (see `enable_uds_socket_credentials`).
/// - `ProcessIdentity::Credentials`, built from the first parsed control message.
/// - `ProcessIdentity::Error(ProcessCredentialsError::ZeroPid)` when that message reports a PID of zero.
/// - `ProcessIdentity::Error(ProcessCredentialsError::InvalidCredentials)` when control data was present but no
///   message could be parsed from it.
///
/// NOTE(review): the received message flags are never inspected. A payload truncated to fit `buf.chunk_mut()`
/// (`MSG_TRUNC`) or truncated control data (`MSG_CTRUNC`) is therefore not reported to the caller.
///
/// ## Errors
///
/// If the underlying `recvmsg` system call fails, the error is returned and `buf` is left unadvanced.
pub(super) fn uds_recvmsg<'sock, S, B: BufMut>(socket: &'sock S, buf: &mut B) -> io::Result<(usize, ConnectionAddress)>
where
    SockRef<'sock>: From<&'sock S>,
{
    let sock_ref = SockRef::from(socket);

    // Create the message header struct that will be populated by the call to `recvmsg`, which includes the peer
    // address, message data, and any ancillary (out-of-band) data.
    //
    // SAFETY: We're allocating `sockaddr_storage`, which is always large enough to hold any address family's socket
    // address structure.
    //
    // The peer address is filled in by `recvmsg` but is not read anywhere below.
    let sock_storage = SockAddrStorage::zeroed();
    let sock_storage_len = sock_storage.size_of();
    let mut sock_addr = unsafe { SockAddr::new(sock_storage, sock_storage_len) };

    // Control buffer for the ancillary data. Its size is determined by `SocketCredentialsAncillaryData`
    // (presumably room for one credentials message; confirm in the `ancillary` module).
    let mut ancillary_data = SocketCredentialsAncillaryData::new();

    // SAFETY: `as_uninit_slice_mut` requires that uninitialized bytes are never written into the slice. Only the
    // kernel writes into it (via `recvmsg`), and `buf` is only advanced by the byte count `recvmsg` reports.
    let data_buf = unsafe { MaybeUninitSlice::new(buf.chunk_mut().as_uninit_slice_mut()) };
    let mut data_bufs = [data_buf];

    let mut msg_hdr = MsgHdrMut::new()
        .with_addr(&mut sock_addr)
        .with_buffers(&mut data_bufs)
        .with_control(ancillary_data.as_mut_uninit());

    // `MSG_CMSG_CLOEXEC` sets close-on-exec on any file descriptors received through `SCM_RIGHTS`.
    let n = sock_ref.recvmsg(&mut msg_hdr, libc::MSG_CMSG_CLOEXEC)?;

    // If we got any socket credentials back, parse them.
    let control_len = msg_hdr.control_len();

    let process_identity = if control_len > 0 {
        // SAFETY: the kernel reported `control_len` bytes of control data, so that prefix of the control buffer is
        // initialized. (Assumed contract of `set_len`; verify against the `ancillary` module.)
        unsafe {
            ancillary_data.set_len(control_len);

            // Only the first control message is considered. The single-arm match means `Credentials` is the only
            // `ControlMessage` variant.
            match ancillary_data
                .messages()
                .map(|m| match m {
                    ControlMessage::Credentials(creds) => creds,
                })
                .next()
            {
                Some(creds) if creds.pid == 0 => ProcessIdentity::Error(ProcessCredentialsError::ZeroPid),
                Some(creds) => ProcessIdentity::Credentials(ProcessCredentials {
                    pid: creds.pid,
                    uid: creds.uid,
                    gid: creds.gid,
                }),
                None => ProcessIdentity::Error(ProcessCredentialsError::InvalidCredentials),
            }
        }
    } else {
        ProcessIdentity::Unavailable
    };

    let conn_addr = ConnectionAddress::ProcessLike(process_identity);

    // Finally, update our buffer to reflect the bytes we've read.
    // SAFETY: `recvmsg` initialized the first `n` bytes of the slice obtained from `buf.chunk_mut()`.
    unsafe {
        buf.advance_mut(n);
    }

    Ok((n, conn_addr))
}
86
87/// Returns `true` if `SO_REUSEPORT` is supported for UDP sockets on the current platform.
88pub fn socket_reuseport_supported() -> bool {
89    let socket = match Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) {
90        Ok(socket) => socket,
91        Err(_) => return false,
92    };
93
94    match socket.set_reuse_port(true) {
95        Ok(()) => true,
96        Err(_) => false,
97    }
98}
99
100/// Sends data to the Unix domain socket.
101///
102/// This function is specifically for connected Unix domain sockets in datagram mode (that's, SOCK_DGRAM), which are
103/// represented via `UnixDatagram` in `tokio`.
104///
105/// The payload is sent with an `SCM_CREDENTIALS` ancillary block using the provided `ProcessCredentials`.
106///
107/// Linux permits a sender to use its own PID, UID, and GID normally. Sending a forged PID requires `CAP_SYS_ADMIN`,
108/// sending a forged UID requires `CAP_SETUID`, and sending a forged GID requires `CAP_SETGID`.
109///
110/// ## Errors
111///
112/// If socket readiness or the underlying system call fails, an error is returned.
113pub async fn uds_sendmsg_with_creds(
114    socket: &UnixDatagram, payload: &[u8], credentials: &ProcessCredentials,
115) -> io::Result<usize> {
116    let creds = libc::ucred {
117        pid: credentials.pid as libc::pid_t,
118        uid: credentials.uid,
119        gid: credentials.gid,
120    };
121
122    socket
123        .async_io(tokio::io::Interest::WRITABLE, || {
124            sendmsg_with_ucred(socket.as_raw_fd(), payload, &creds)
125        })
126        .await
127}
128
129/// Synchronously writes one payload with the given credentials to the raw file descriptor.
130///
131/// Constructs a `cmsghdr` header followed by the `ucred` body in a single control buffer, then invokes `sendmsg`.
132fn sendmsg_with_ucred(fd: libc::c_int, payload: &[u8], creds: &libc::ucred) -> io::Result<usize> {
133    // SAFETY: `CMSG_SPACE` is a const expression on `size_of::<ucred>()`; the call is safe and returns the byte count
134    // needed to hold one aligned cmsghdr plus a ucred payload.
135    let control_len = unsafe { libc::CMSG_SPACE(mem::size_of::<libc::ucred>() as u32) as usize };
136
137    // `CMSG_SPACE` gives us the padded byte length, but the backing storage also has to be aligned for `cmsghdr` and
138    // `ucred` because the CMSG_* macros return typed pointers into it.
139    let control_words = control_len.div_ceil(mem::size_of::<usize>());
140    let mut control_buf = vec![0usize; control_words];
141
142    // SAFETY: we construct a `msghdr` pointing at the payload and the control buffer, then walk the control buffer
143    // with the libc CMSG_FIRSTHDR / CMSG_DATA macros to write the cmsghdr header and ucred body. Pointers all
144    // reference live local memory; lifetimes don't escape the call.
145    let n = unsafe {
146        // We use `IoSlice`-style iovec entries pointing at the payload.
147        let mut iov = libc::iovec {
148            iov_base: payload.as_ptr() as *mut libc::c_void,
149            iov_len: payload.len(),
150        };
151
152        let mut msg: libc::msghdr = mem::zeroed();
153        msg.msg_iov = &mut iov;
154        msg.msg_iovlen = 1;
155        msg.msg_control = control_buf.as_mut_ptr().cast::<libc::c_void>();
156        msg.msg_controllen = control_len as _;
157
158        // Populate the cmsghdr at the start of the control buffer.
159        let cmsg = libc::CMSG_FIRSTHDR(&msg);
160        if cmsg.is_null() {
161            return Err(io::Error::other("failed to obtain cmsghdr from control buffer"));
162        }
163        (*cmsg).cmsg_level = libc::SOL_SOCKET;
164        (*cmsg).cmsg_type = libc::SCM_CREDENTIALS;
165        (*cmsg).cmsg_len = libc::CMSG_LEN(mem::size_of::<libc::ucred>() as u32) as _;
166
167        // Copy the ucred body into the cmsg data region.
168        let data_ptr = libc::CMSG_DATA(cmsg) as *mut libc::ucred;
169        std::ptr::write(data_ptr, *creds);
170
171        // Send.
172        libc::sendmsg(fd, &msg, libc::MSG_NOSIGNAL)
173    };
174
175    if n < 0 {
176        Err(io::Error::last_os_error())
177    } else {
178        Ok(n as usize)
179    }
180}
181
#[cfg(test)]
mod tests {
    use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};

    use super::*;

    /// Creates a connected pair of Unix datagram sockets, returned as `(sender, receiver)`.
    fn datagram_socketpair() -> (OwnedFd, OwnedFd) {
        let mut fds: [libc::c_int; 2] = [-1, -1];
        // SAFETY: `fds` has room for the two descriptors `socketpair` writes on success.
        let rc = unsafe { libc::socketpair(libc::AF_UNIX, libc::SOCK_DGRAM, 0, fds.as_mut_ptr()) };
        assert_eq!(rc, 0, "socketpair failed: {}", io::Error::last_os_error());
        // SAFETY: `socketpair` succeeded, so both descriptors are open and owned exclusively by this function.
        unsafe { (OwnedFd::from_raw_fd(fds[0]), OwnedFd::from_raw_fd(fds[1])) }
    }

    /// Builds a `ucred` describing the current process, which Linux accepts without extra capabilities.
    fn current_ucred() -> libc::ucred {
        // SAFETY: `getuid` and `getgid` take no arguments and always succeed.
        let (uid, gid) = unsafe { (libc::getuid(), libc::getgid()) };
        libc::ucred {
            pid: std::process::id() as libc::pid_t,
            uid,
            gid,
        }
    }

    #[test]
    fn sendmsg_with_current_credentials_round_trips_payload() {
        // Send a payload with our own creds over a socketpair, read it back with a plain `recv`, and assert the
        // payload bytes match. This exercises the sendmsg construction path without requiring elevated capabilities.
        let (sender, receiver) = datagram_socketpair();

        let creds = current_ucred();
        let payload = b"uds-sendmsg-test-payload";
        let written = sendmsg_with_ucred(sender.as_raw_fd(), payload, &creds).expect("send should succeed");
        assert_eq!(written, payload.len());

        // Read back to confirm the receiver got the bytes.
        let mut buf = [0u8; 64];
        // SAFETY: `buf` is a live, writable buffer of `buf.len()` bytes.
        let read = unsafe {
            libc::recv(
                receiver.as_raw_fd(),
                buf.as_mut_ptr() as *mut libc::c_void,
                buf.len(),
                0,
            )
        };
        assert!(read > 0, "recv failed: {}", io::Error::last_os_error());
        assert_eq!(&buf[..read as usize], payload);
    }

    #[test]
    fn sendmsg_with_current_credentials_delivers_scm_credentials() {
        // Enable `SO_PASSCRED` on the receiver so the kernel hands the credentials back. Then read the datagram
        // through `uds_recvmsg` and check both the payload and the parsed process identity, covering the send and
        // receive paths end to end.
        let (sender, receiver) = datagram_socketpair();
        enable_uds_socket_credentials(&receiver).expect("enabling SO_PASSCRED should succeed");

        let creds = current_ucred();
        let payload = b"uds-sendmsg-creds-payload";
        let written = sendmsg_with_ucred(sender.as_raw_fd(), payload, &creds).expect("send should succeed");
        assert_eq!(written, payload.len());

        let mut buf = bytes::BytesMut::with_capacity(64);
        let (read, addr) = uds_recvmsg(&receiver, &mut buf).expect("recvmsg should succeed");
        assert_eq!(read, payload.len());
        assert_eq!(&buf[..], &payload[..]);

        let ConnectionAddress::ProcessLike(ProcessIdentity::Credentials(received)) = addr else {
            panic!("expected process credentials in the connection address");
        };
        assert_eq!(received.pid as i64, i64::from(creds.pid));
        assert_eq!(received.uid, creds.uid);
        assert_eq!(received.gid, creds.gid);
    }
}