saluki_io/net/unix/
linux.rs1use std::{io, mem, os::fd::AsRawFd};
2
3use bytes::BufMut;
4use socket2::{Domain, MaybeUninitSlice, MsgHdrMut, Protocol, SockAddr, SockAddrStorage, SockRef, Socket, Type};
5use tokio::net::UnixDatagram;
6
7use super::ancillary::{ControlMessage, SocketCredentialsAncillaryData};
8use crate::net::addr::{ConnectionAddress, ProcessCredentials, ProcessCredentialsError, ProcessIdentity};
9
10pub fn enable_uds_socket_credentials<'sock, S>(socket: &'sock S) -> io::Result<()>
16where
17 SockRef<'sock>: From<&'sock S>,
18{
19 let sock_ref = SockRef::from(socket);
20 sock_ref.set_passcred(true)
21}
22
23pub(super) fn uds_recvmsg<'sock, S, B: BufMut>(socket: &'sock S, buf: &mut B) -> io::Result<(usize, ConnectionAddress)>
24where
25 SockRef<'sock>: From<&'sock S>,
26{
27 let sock_ref = SockRef::from(socket);
28
29 let sock_storage = SockAddrStorage::zeroed();
35 let sock_storage_len = sock_storage.size_of();
36 let mut sock_addr = unsafe { SockAddr::new(sock_storage, sock_storage_len) };
37
38 let mut ancillary_data = SocketCredentialsAncillaryData::new();
39
40 let data_buf = unsafe { MaybeUninitSlice::new(buf.chunk_mut().as_uninit_slice_mut()) };
41 let mut data_bufs = [data_buf];
42
43 let mut msg_hdr = MsgHdrMut::new()
44 .with_addr(&mut sock_addr)
45 .with_buffers(&mut data_bufs)
46 .with_control(ancillary_data.as_mut_uninit());
47
48 let n = sock_ref.recvmsg(&mut msg_hdr, libc::MSG_CMSG_CLOEXEC)?;
49
50 let control_len = msg_hdr.control_len();
52
53 let process_identity = if control_len > 0 {
54 unsafe {
55 ancillary_data.set_len(control_len);
56
57 match ancillary_data
58 .messages()
59 .map(|m| match m {
60 ControlMessage::Credentials(creds) => creds,
61 })
62 .next()
63 {
64 Some(creds) if creds.pid == 0 => ProcessIdentity::Error(ProcessCredentialsError::ZeroPid),
65 Some(creds) => ProcessIdentity::Credentials(ProcessCredentials {
66 pid: creds.pid,
67 uid: creds.uid,
68 gid: creds.gid,
69 }),
70 None => ProcessIdentity::Error(ProcessCredentialsError::InvalidCredentials),
71 }
72 }
73 } else {
74 ProcessIdentity::Unavailable
75 };
76
77 let conn_addr = ConnectionAddress::ProcessLike(process_identity);
78
79 unsafe {
81 buf.advance_mut(n);
82 }
83
84 Ok((n, conn_addr))
85}
86
87pub fn socket_reuseport_supported() -> bool {
89 let socket = match Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) {
90 Ok(socket) => socket,
91 Err(_) => return false,
92 };
93
94 match socket.set_reuse_port(true) {
95 Ok(()) => true,
96 Err(_) => false,
97 }
98}
99
100pub async fn uds_sendmsg_with_creds(
114 socket: &UnixDatagram, payload: &[u8], credentials: &ProcessCredentials,
115) -> io::Result<usize> {
116 let creds = libc::ucred {
117 pid: credentials.pid as libc::pid_t,
118 uid: credentials.uid,
119 gid: credentials.gid,
120 };
121
122 socket
123 .async_io(tokio::io::Interest::WRITABLE, || {
124 sendmsg_with_ucred(socket.as_raw_fd(), payload, &creds)
125 })
126 .await
127}
128
129fn sendmsg_with_ucred(fd: libc::c_int, payload: &[u8], creds: &libc::ucred) -> io::Result<usize> {
133 let control_len = unsafe { libc::CMSG_SPACE(mem::size_of::<libc::ucred>() as u32) as usize };
136
137 let control_words = control_len.div_ceil(mem::size_of::<usize>());
140 let mut control_buf = vec![0usize; control_words];
141
142 let n = unsafe {
146 let mut iov = libc::iovec {
148 iov_base: payload.as_ptr() as *mut libc::c_void,
149 iov_len: payload.len(),
150 };
151
152 let mut msg: libc::msghdr = mem::zeroed();
153 msg.msg_iov = &mut iov;
154 msg.msg_iovlen = 1;
155 msg.msg_control = control_buf.as_mut_ptr().cast::<libc::c_void>();
156 msg.msg_controllen = control_len as _;
157
158 let cmsg = libc::CMSG_FIRSTHDR(&msg);
160 if cmsg.is_null() {
161 return Err(io::Error::other("failed to obtain cmsghdr from control buffer"));
162 }
163 (*cmsg).cmsg_level = libc::SOL_SOCKET;
164 (*cmsg).cmsg_type = libc::SCM_CREDENTIALS;
165 (*cmsg).cmsg_len = libc::CMSG_LEN(mem::size_of::<libc::ucred>() as u32) as _;
166
167 let data_ptr = libc::CMSG_DATA(cmsg) as *mut libc::ucred;
169 std::ptr::write(data_ptr, *creds);
170
171 libc::sendmsg(fd, &msg, libc::MSG_NOSIGNAL)
173 };
174
175 if n < 0 {
176 Err(io::Error::last_os_error())
177 } else {
178 Ok(n as usize)
179 }
180}
181
#[cfg(test)]
mod tests {
    use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};

    use super::*;

    /// Sends a payload with the current process's own credentials over a Unix datagram socket pair
    /// and checks that the payload arrives unchanged on the other end.
    #[test]
    fn sendmsg_with_current_credentials_round_trips_payload() {
        let mut raw_fds: [libc::c_int; 2] = [-1, -1];
        // SAFETY: `raw_fds` has room for the two descriptors `socketpair` writes.
        let rc = unsafe { libc::socketpair(libc::AF_UNIX, libc::SOCK_DGRAM, 0, raw_fds.as_mut_ptr()) };
        assert_eq!(rc, 0, "socketpair failed: {}", io::Error::last_os_error());

        // SAFETY: `socketpair` succeeded, so both descriptors are open and owned solely by this
        // test; wrapping them in `OwnedFd` closes them when the test ends.
        let sender = unsafe { OwnedFd::from_raw_fd(raw_fds[0]) };
        let receiver = unsafe { OwnedFd::from_raw_fd(raw_fds[1]) };

        // Use this process's real identity for the credentials being sent.
        let own_creds = libc::ucred {
            pid: std::process::id() as libc::pid_t,
            uid: unsafe { libc::getuid() },
            gid: unsafe { libc::getgid() },
        };

        let payload = b"uds-sendmsg-test-payload";
        let written = sendmsg_with_ucred(sender.as_raw_fd(), payload, &own_creds).expect("send should succeed");
        assert_eq!(written, payload.len());

        let mut recv_buf = [0u8; 64];
        // SAFETY: the pointer and length describe `recv_buf`, which is valid for writes.
        let read = unsafe {
            libc::recv(
                receiver.as_raw_fd(),
                recv_buf.as_mut_ptr() as *mut libc::c_void,
                recv_buf.len(),
                0,
            )
        };
        assert!(read > 0, "recv failed: {}", io::Error::last_os_error());
        assert_eq!(&recv_buf[..read as usize], payload);
    }
}