// std/sys/net/connection/socket/windows.rs

#![unstable(issue = "none", feature = "windows_net")]

use core::ffi::{c_int, c_long, c_ulong, c_ushort};

use super::{getsockopt, setsockopt, socket_addr_from_c, socket_addr_to_c};
use crate::io::{self, BorrowedBuf, BorrowedCursor, IoSlice, IoSliceMut, Read};
use crate::net::{Shutdown, SocketAddr};
use crate::os::windows::io::{
    AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket,
};
use crate::sys::pal::winsock::last_error;
use crate::sys::{AsInner, FromInner, IntoInner, c};
use crate::time::Duration;
use crate::{cmp, mem, ptr, sys};
15
/// Integer type used for buffer lengths passed to Winsock send/receive calls (a C `int`).
// Lowercase C-style name kept deliberately so it reads like the Unix counterpart.
#[allow(non_camel_case_types)]
pub type wrlen_t = i32;
18
19pub(super) mod netc {
20    //! BSD socket compatibility shim
21    //!
22    //! Some Windows API types are not quite what's expected by our cross-platform
23    //! net code. E.g. naming differences or different pointer types.
24
25    use core::ffi::{c_char, c_int, c_uint, c_ulong, c_ushort, c_void};
26
27    use crate::sys::c::{self, ADDRESS_FAMILY, ADDRINFOA, SOCKADDR, SOCKET};
28    // re-exports from Windows API bindings.
29    pub use crate::sys::c::{
30        ADDRESS_FAMILY as sa_family_t, ADDRINFOA as addrinfo, IP_ADD_MEMBERSHIP,
31        IP_DROP_MEMBERSHIP, IP_MULTICAST_LOOP, IP_MULTICAST_TTL, IP_TTL, IPPROTO_IP, IPPROTO_IPV6,
32        IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_MULTICAST_LOOP, IPV6_V6ONLY, SO_BROADCAST,
33        SO_RCVTIMEO, SO_SNDTIMEO, SOCK_DGRAM, SOCK_STREAM, SOCKADDR as sockaddr,
34        SOCKADDR_STORAGE as sockaddr_storage, SOL_SOCKET, bind, connect, freeaddrinfo, getpeername,
35        getsockname, getsockopt, listen, setsockopt,
36    };
37
38    #[allow(non_camel_case_types)]
39    pub type socklen_t = c_int;
40
41    pub const AF_INET: i32 = c::AF_INET as i32;
42    pub const AF_INET6: i32 = c::AF_INET6 as i32;
43
44    // The following two structs use a union in the generated bindings but
45    // our cross-platform code expects a normal field so it's redefined here.
46    // As a consequence, we also need to redefine other structs that use this struct.
47    #[repr(C)]
48    #[derive(Copy, Clone)]
49    pub struct in_addr {
50        pub s_addr: u32,
51    }
52
53    #[repr(C)]
54    #[derive(Copy, Clone)]
55    pub struct in6_addr {
56        pub s6_addr: [u8; 16],
57    }
58
59    #[repr(C)]
60    pub struct ip_mreq {
61        pub imr_multiaddr: in_addr,
62        pub imr_interface: in_addr,
63    }
64
65    #[repr(C)]
66    pub struct ipv6_mreq {
67        pub ipv6mr_multiaddr: in6_addr,
68        pub ipv6mr_interface: c_uint,
69    }
70
71    #[repr(C)]
72    #[derive(Copy, Clone)]
73    pub struct sockaddr_in {
74        pub sin_family: ADDRESS_FAMILY,
75        pub sin_port: c_ushort,
76        pub sin_addr: in_addr,
77        pub sin_zero: [c_char; 8],
78    }
79
80    #[repr(C)]
81    #[derive(Copy, Clone)]
82    pub struct sockaddr_in6 {
83        pub sin6_family: ADDRESS_FAMILY,
84        pub sin6_port: c_ushort,
85        pub sin6_flowinfo: c_ulong,
86        pub sin6_addr: in6_addr,
87        pub sin6_scope_id: c_ulong,
88    }
89
90    pub unsafe fn send(socket: SOCKET, buf: *const c_void, len: c_int, flags: c_int) -> c_int {
91        unsafe { c::send(socket, buf.cast::<u8>(), len, flags) }
92    }
93    pub unsafe fn sendto(
94        socket: SOCKET,
95        buf: *const c_void,
96        len: c_int,
97        flags: c_int,
98        addr: *const SOCKADDR,
99        addrlen: c_int,
100    ) -> c_int {
101        unsafe { c::sendto(socket, buf.cast::<u8>(), len, flags, addr, addrlen) }
102    }
103    pub unsafe fn getaddrinfo(
104        node: *const c_char,
105        service: *const c_char,
106        hints: *const ADDRINFOA,
107        res: *mut *mut ADDRINFOA,
108    ) -> c_int {
109        unsafe { c::getaddrinfo(node.cast::<u8>(), service.cast::<u8>(), hints, res) }
110    }
111}
112
113pub use crate::sys::pal::winsock::{cvt, cvt_gai, cvt_r, startup as init};
114
115#[expect(missing_debug_implementations)]
116pub struct Socket(OwnedSocket);
117
118impl Socket {
119    pub fn new(family: c_int, ty: c_int) -> io::Result<Socket> {
120        let socket = unsafe {
121            c::WSASocketW(
122                family,
123                ty,
124                0,
125                ptr::null_mut(),
126                0,
127                c::WSA_FLAG_OVERLAPPED | c::WSA_FLAG_NO_HANDLE_INHERIT,
128            )
129        };
130
131        if socket != c::INVALID_SOCKET {
132            unsafe { Ok(Self::from_raw(socket)) }
133        } else {
134            let error = unsafe { c::WSAGetLastError() };
135
136            if error != c::WSAEPROTOTYPE && error != c::WSAEINVAL {
137                return Err(io::Error::from_raw_os_error(error));
138            }
139
140            let socket =
141                unsafe { c::WSASocketW(family, ty, 0, ptr::null_mut(), 0, c::WSA_FLAG_OVERLAPPED) };
142
143            if socket == c::INVALID_SOCKET {
144                return Err(last_error());
145            }
146
147            unsafe {
148                let socket = Self::from_raw(socket);
149                socket.0.set_no_inherit()?;
150                Ok(socket)
151            }
152        }
153    }
154
155    pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> {
156        let (addr, len) = socket_addr_to_c(addr);
157        let result = unsafe { c::connect(self.as_raw(), addr.as_ptr(), len) };
158        cvt(result).map(drop)
159    }
160
161    pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
162        self.set_nonblocking(true)?;
163        let result = self.connect(addr);
164        self.set_nonblocking(false)?;
165
166        match result {
167            Err(ref error) if error.kind() == io::ErrorKind::WouldBlock => {
168                if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
169                    return Err(io::Error::ZERO_TIMEOUT);
170                }
171
172                let mut timeout = c::TIMEVAL {
173                    tv_sec: cmp::min(timeout.as_secs(), c_long::MAX as u64) as c_long,
174                    tv_usec: timeout.subsec_micros() as c_long,
175                };
176
177                if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
178                    timeout.tv_usec = 1;
179                }
180
181                let fds = {
182                    let mut fds = unsafe { mem::zeroed::<c::FD_SET>() };
183                    fds.fd_count = 1;
184                    fds.fd_array[0] = self.as_raw();
185                    fds
186                };
187
188                let mut writefds = fds;
189                let mut errorfds = fds;
190
191                let count = {
192                    let result = unsafe {
193                        c::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout)
194                    };
195                    cvt(result)?
196                };
197
198                match count {
199                    0 => Err(io::const_error!(io::ErrorKind::TimedOut, "connection timed out")),
200                    _ => {
201                        if writefds.fd_count != 1 {
202                            if let Some(e) = self.take_error()? {
203                                return Err(e);
204                            }
205                        }
206
207                        Ok(())
208                    }
209                }
210            }
211            _ => result,
212        }
213    }
214
215    pub fn accept(&self, storage: *mut c::SOCKADDR, len: *mut c_int) -> io::Result<Socket> {
216        let socket = unsafe { c::accept(self.as_raw(), storage, len) };
217
218        match socket {
219            c::INVALID_SOCKET => Err(last_error()),
220            _ => unsafe { Ok(Self::from_raw(socket)) },
221        }
222    }
223
224    pub fn duplicate(&self) -> io::Result<Socket> {
225        Ok(Self(self.0.try_clone()?))
226    }
227
228    fn recv_with_flags(&self, mut buf: BorrowedCursor<'_>, flags: c_int) -> io::Result<()> {
229        // On unix when a socket is shut down all further reads return 0, so we
230        // do the same on windows to map a shut down socket to returning EOF.
231        let length = cmp::min(buf.capacity(), i32::MAX as usize) as i32;
232        let result =
233            unsafe { c::recv(self.as_raw(), buf.as_mut().as_mut_ptr() as *mut _, length, flags) };
234
235        match result {
236            c::SOCKET_ERROR => {
237                let error = unsafe { c::WSAGetLastError() };
238
239                if error == c::WSAESHUTDOWN {
240                    Ok(())
241                } else {
242                    Err(io::Error::from_raw_os_error(error))
243                }
244            }
245            _ => {
246                unsafe { buf.advance(result as usize) };
247                Ok(())
248            }
249        }
250    }
251
252    pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
253        let mut buf = BorrowedBuf::from(buf);
254        self.recv_with_flags(buf.unfilled(), 0)?;
255        Ok(buf.len())
256    }
257
258    pub fn read_buf(&self, buf: BorrowedCursor<'_>) -> io::Result<()> {
259        self.recv_with_flags(buf, 0)
260    }
261
262    pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
263        // On unix when a socket is shut down all further reads return 0, so we
264        // do the same on windows to map a shut down socket to returning EOF.
265        let length = cmp::min(bufs.len(), u32::MAX as usize) as u32;
266        let mut nread = 0;
267        let mut flags = 0;
268        let result = unsafe {
269            c::WSARecv(
270                self.as_raw(),
271                bufs.as_mut_ptr() as *mut c::WSABUF,
272                length,
273                &mut nread,
274                &mut flags,
275                ptr::null_mut(),
276                None,
277            )
278        };
279
280        match result {
281            0 => Ok(nread as usize),
282            _ => {
283                let error = unsafe { c::WSAGetLastError() };
284
285                if error == c::WSAESHUTDOWN {
286                    Ok(0)
287                } else {
288                    Err(io::Error::from_raw_os_error(error))
289                }
290            }
291        }
292    }
293
294    #[inline]
295    pub fn is_read_vectored(&self) -> bool {
296        true
297    }
298
299    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
300        let mut buf = BorrowedBuf::from(buf);
301        self.recv_with_flags(buf.unfilled(), c::MSG_PEEK)?;
302        Ok(buf.len())
303    }
304
305    fn recv_from_with_flags(
306        &self,
307        buf: &mut [u8],
308        flags: c_int,
309    ) -> io::Result<(usize, SocketAddr)> {
310        let mut storage = unsafe { mem::zeroed::<c::SOCKADDR_STORAGE>() };
311        let mut addrlen = size_of_val(&storage) as netc::socklen_t;
312        let length = cmp::min(buf.len(), <wrlen_t>::MAX as usize) as wrlen_t;
313
314        // On unix when a socket is shut down all further reads return 0, so we
315        // do the same on windows to map a shut down socket to returning EOF.
316        let result = unsafe {
317            c::recvfrom(
318                self.as_raw(),
319                buf.as_mut_ptr() as *mut _,
320                length,
321                flags,
322                (&raw mut storage) as *mut _,
323                &mut addrlen,
324            )
325        };
326
327        match result {
328            c::SOCKET_ERROR => {
329                let error = unsafe { c::WSAGetLastError() };
330
331                if error == c::WSAESHUTDOWN {
332                    Ok((0, unsafe { socket_addr_from_c(&storage, addrlen as usize)? }))
333                } else {
334                    Err(io::Error::from_raw_os_error(error))
335                }
336            }
337            _ => Ok((result as usize, unsafe { socket_addr_from_c(&storage, addrlen as usize)? })),
338        }
339    }
340
341    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
342        self.recv_from_with_flags(buf, 0)
343    }
344
345    pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
346        self.recv_from_with_flags(buf, c::MSG_PEEK)
347    }
348
349    pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
350        let length = cmp::min(bufs.len(), u32::MAX as usize) as u32;
351        let mut nwritten = 0;
352        let result = unsafe {
353            c::WSASend(
354                self.as_raw(),
355                bufs.as_ptr() as *const c::WSABUF as *mut _,
356                length,
357                &mut nwritten,
358                0,
359                ptr::null_mut(),
360                None,
361            )
362        };
363        cvt(result).map(|_| nwritten as usize)
364    }
365
366    #[inline]
367    pub fn is_write_vectored(&self) -> bool {
368        true
369    }
370
371    pub fn set_timeout(&self, dur: Option<Duration>, kind: c_int) -> io::Result<()> {
372        let timeout = match dur {
373            Some(dur) => {
374                let timeout = sys::dur2timeout(dur);
375                if timeout == 0 {
376                    return Err(io::Error::ZERO_TIMEOUT);
377                }
378                timeout
379            }
380            None => 0,
381        };
382        unsafe { setsockopt(self, c::SOL_SOCKET, kind, timeout) }
383    }
384
385    pub fn timeout(&self, kind: c_int) -> io::Result<Option<Duration>> {
386        let raw: u32 = unsafe { getsockopt(self, c::SOL_SOCKET, kind)? };
387        if raw == 0 {
388            Ok(None)
389        } else {
390            let secs = raw / 1000;
391            let nsec = (raw % 1000) * 1000000;
392            Ok(Some(Duration::new(secs as u64, nsec as u32)))
393        }
394    }
395
396    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
397        let how = match how {
398            Shutdown::Write => c::SD_SEND,
399            Shutdown::Read => c::SD_RECEIVE,
400            Shutdown::Both => c::SD_BOTH,
401        };
402        let result = unsafe { c::shutdown(self.as_raw(), how) };
403        cvt(result).map(drop)
404    }
405
406    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
407        let mut nonblocking = nonblocking as c_ulong;
408        let result =
409            unsafe { c::ioctlsocket(self.as_raw(), c::FIONBIO as c_int, &mut nonblocking) };
410        cvt(result).map(drop)
411    }
412
413    pub fn set_linger(&self, linger: Option<Duration>) -> io::Result<()> {
414        let linger = c::LINGER {
415            l_onoff: linger.is_some() as c_ushort,
416            l_linger: cmp::min(linger.unwrap_or_default().as_secs(), c_ushort::MAX as u64)
417                as c_ushort,
418        };
419
420        unsafe { setsockopt(self, c::SOL_SOCKET, c::SO_LINGER, linger) }
421    }
422
423    pub fn linger(&self) -> io::Result<Option<Duration>> {
424        let val: c::LINGER = unsafe { getsockopt(self, c::SOL_SOCKET, c::SO_LINGER)? };
425
426        Ok((val.l_onoff != 0).then(|| Duration::from_secs(val.l_linger as u64)))
427    }
428
429    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
430        unsafe { setsockopt(self, c::IPPROTO_TCP, c::TCP_NODELAY, nodelay as c::BOOL) }
431    }
432
433    pub fn nodelay(&self) -> io::Result<bool> {
434        let raw: c::BOOL = unsafe { getsockopt(self, c::IPPROTO_TCP, c::TCP_NODELAY)? };
435        Ok(raw != 0)
436    }
437
438    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
439        let raw: c_int = unsafe { getsockopt(self, c::SOL_SOCKET, c::SO_ERROR)? };
440        if raw == 0 { Ok(None) } else { Ok(Some(io::Error::from_raw_os_error(raw as i32))) }
441    }
442
443    pub fn as_raw(&self) -> c::SOCKET {
444        debug_assert_eq!(size_of::<c::SOCKET>(), size_of::<RawSocket>());
445        debug_assert_eq!(align_of::<c::SOCKET>(), align_of::<RawSocket>());
446        self.as_inner().as_raw_socket() as c::SOCKET
447    }
448    pub unsafe fn from_raw(raw: c::SOCKET) -> Self {
449        debug_assert_eq!(size_of::<c::SOCKET>(), size_of::<RawSocket>());
450        debug_assert_eq!(align_of::<c::SOCKET>(), align_of::<RawSocket>());
451        unsafe { Self::from_raw_socket(raw as RawSocket) }
452    }
453}
454
455impl<'a> Read for &'a Socket {
456    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
457        (**self).read(buf)
458    }
459}
460
461impl AsInner<OwnedSocket> for Socket {
462    #[inline]
463    fn as_inner(&self) -> &OwnedSocket {
464        &self.0
465    }
466}
467
468impl FromInner<OwnedSocket> for Socket {
469    fn from_inner(sock: OwnedSocket) -> Socket {
470        Socket(sock)
471    }
472}
473
474impl IntoInner<OwnedSocket> for Socket {
475    fn into_inner(self) -> OwnedSocket {
476        self.0
477    }
478}
479
480impl AsSocket for Socket {
481    fn as_socket(&self) -> BorrowedSocket<'_> {
482        self.0.as_socket()
483    }
484}
485
486impl AsRawSocket for Socket {
487    fn as_raw_socket(&self) -> RawSocket {
488        self.0.as_raw_socket()
489    }
490}
491
492impl IntoRawSocket for Socket {
493    fn into_raw_socket(self) -> RawSocket {
494        self.0.into_raw_socket()
495    }
496}
497
498impl FromRawSocket for Socket {
499    unsafe fn from_raw_socket(raw_socket: RawSocket) -> Self {
500        unsafe { Self(FromRawSocket::from_raw_socket(raw_socket)) }
501    }
502}