diff --git a/Cargo.toml b/Cargo.toml index df52f86..061468e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,11 @@ [package] name = "unix_socket" -version = "0.4.2" +version = "0.4.3" authors = ["Steven Fackler "] license = "MIT" description = "Unix domain socket bindings" repository = "https://github.com/sfackler/rust-unix-socket" -documentation = "https://sfackler.github.io/rust-unix-socket/doc/v0.4.2/unix_socket" +documentation = "https://sfackler.github.io/rust-unix-socket/doc/v0.4.3/unix_socket" readme = "README.md" keywords = ["posix", "unix", "socket", "domain"] diff --git a/src/lib.rs b/src/lib.rs index d170dc3..0848a64 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ //! Support for Unix domain socket clients and servers. #![warn(missing_docs)] -#![doc(html_root_url="https://sfackler.github.io/rust-unix-socket/doc/v0.4.2")] +#![doc(html_root_url="https://sfackler.github.io/rust-unix-socket/doc/v0.4.3")] #![cfg_attr(feature = "socket_timeout", feature(duration))] #![cfg_attr(all(test, feature = "socket_timeout"), feature(duration_span))] @@ -47,6 +47,22 @@ fn sun_path_offset() -> usize { } } +fn cvt(v: libc::c_int) -> io::Result { + if v < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(v) + } +} + +fn cvt_s(v: libc::ssize_t) -> io::Result { + if v < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(v) + } +} + struct Inner(RawFd); impl Drop for Inner { @@ -58,31 +74,97 @@ impl Drop for Inner { } impl Inner { - unsafe fn new() -> io::Result { - let fd = libc::socket(libc::AF_UNIX, libc::SOCK_STREAM, 0); - if fd < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(Inner(fd)) + fn new(kind: libc::c_int) -> io::Result { + unsafe { + cvt(libc::socket(libc::AF_UNIX, kind, 0)).map(Inner) } } - unsafe fn new_pair() -> io::Result<(Inner, Inner)> { - let mut fds = [0, 0]; - let res = socketpair(libc::AF_UNIX, libc::SOCK_STREAM, 0, &mut fds); - if res < 0 { - return Err(io::Error::last_os_error()); + fn new_pair() -> io::Result<(Inner, Inner)> { + unsafe { + let mut fds = [0, 0]; + try!(cvt(socketpair(libc::AF_UNIX, libc::SOCK_STREAM, 0, &mut fds))); + Ok((Inner(fds[0]), Inner(fds[1]))) } - debug_assert_eq!(res, 0); - Ok((Inner(fds[0]), Inner(fds[1]))) } fn try_clone(&self) -> io::Result { - let fd = unsafe { libc::dup(self.0) }; - if fd < 0 { - Err(io::Error::last_os_error()) + unsafe { + cvt(libc::dup(self.0)).map(Inner) + } + } + + fn shutdown(&self, how: Shutdown) -> io::Result<()> { + let how = match how { + Shutdown::Read => libc::SHUT_RD, + Shutdown::Write => libc::SHUT_WR, + Shutdown::Both => libc::SHUT_RDWR, + }; + + unsafe { + cvt(libc::shutdown(self.0, how)).map(|_| ()) + } + } + + #[cfg(feature = "socket_timeout")] + fn timeout(&self, kind: libc::c_int) -> io::Result> { + let timeout = unsafe { + let mut timeout: libc::timeval = mem::zeroed(); + let mut size = mem::size_of::() as libc::socklen_t; + try!(cvt(getsockopt(self.0, + libc::SOL_SOCKET, + kind, + &mut timeout as *mut _ as *mut _, + &mut size as *mut _ as *mut _))); + timeout + }; + + if timeout.tv_sec == 0 && timeout.tv_usec == 0 { + Ok(None) } else { - Ok(Inner(fd)) + Ok(Some(std::time::Duration::new(timeout.tv_sec as u64, + (timeout.tv_usec as u32) * 1000))) + } + } + + #[cfg(feature = "socket_timeout")] + fn set_timeout(&self, dur: Option, kind: libc::c_int) -> io::Result<()> { + let timeout = match dur { + Some(dur) => { + if dur.secs() == 0 && dur.extra_nanos() == 0 { + return Err(io::Error::new(io::ErrorKind::InvalidInput, + "cannot set a 0 duration timeout")); + } + + let secs = if dur.secs() > libc::time_t::max_value() as u64 { + libc::time_t::max_value() + } else { + dur.secs() as libc::time_t + }; + let mut timeout = libc::timeval { + tv_sec: secs, + tv_usec: (dur.extra_nanos() / 1000) as libc::suseconds_t, + }; + if timeout.tv_sec == 0 && timeout.tv_usec == 0 { + timeout.tv_usec = 1; + } + timeout + } + None => { + libc::timeval { + tv_sec: 0, + tv_usec: 0, + } + } + }; + + unsafe { + cvt(libc::setsockopt(self.0, + libc::SOL_SOCKET, + kind, + &timeout as *const _ as *const _, + mem::size_of::() as libc::socklen_t)) + .map(|_| ()) } } } @@ -148,19 +230,12 @@ impl Clone for SocketAddr { } impl SocketAddr { - fn new(fd: RawFd, - f: unsafe extern "system" fn(libc::c_int, - *mut libc::sockaddr, - *mut libc::socklen_t) -> libc::c_int) - -> io::Result { + fn new(f: F) -> io::Result + where F: FnOnce(*mut libc::sockaddr, *mut libc::socklen_t) -> libc::c_int { unsafe { let mut addr: libc::sockaddr_un = mem::zeroed(); let mut len = mem::size_of::() as libc::socklen_t; - let ret = f(fd, &mut addr as *mut _ as *mut _, &mut len); - - if ret != 0 { - return Err(io::Error::last_os_error()); - } + try!(cvt(f(&mut addr as *mut _ as *mut _, &mut len))); if addr.sun_family != libc::AF_UNIX as libc::sa_family_t { return Err(io::Error::new(io::ErrorKind::InvalidInput, @@ -212,7 +287,7 @@ impl<'a> fmt::Display for AsciiEscaped<'a> { } } -/// A stream which communicates over a Unix domain socket. +/// A Unix stream socket. /// /// # Examples /// @@ -254,7 +329,7 @@ impl UnixStream { /// corresponding to a path on the filesystem. pub fn connect>(path: P) -> io::Result { unsafe { - let inner = try!(Inner::new()); + let inner = try!(Inner::new(libc::SOCK_STREAM)); let (addr, len) = try!(sockaddr_un(path)); let ret = libc::connect(inner.0, &addr as *const _ as *const _, len); @@ -272,10 +347,8 @@ impl UnixStream { /// /// Returns two `UnixStream`s which are connected to each other. pub fn unnamed() -> io::Result<(UnixStream, UnixStream)> { - unsafe { - let (i1, i2) = try!(Inner::new_pair()); - Ok((UnixStream { inner: i1 }, UnixStream { inner: i2 })) - } + let (i1, i2) = try!(Inner::new_pair()); + Ok((UnixStream { inner: i1 }, UnixStream { inner: i2 })) } /// Create a new independently owned handle to the underlying socket. @@ -292,12 +365,12 @@ impl UnixStream { /// Returns the socket address of the local half of this connection. pub fn local_addr(&self) -> io::Result { - SocketAddr::new(self.inner.0, libc::getsockname) + SocketAddr::new(|addr, len| unsafe { libc::getsockname(self.inner.0, addr, len) }) } /// Returns the socket address of the remote half of this connection. pub fn peer_addr(&self) -> io::Result { - SocketAddr::new(self.inner.0, libc::getpeername) + SocketAddr::new(|addr, len| unsafe { libc::getpeername(self.inner.0, addr, len) }) } /// Sets the read timeout for the socket. @@ -309,7 +382,7 @@ impl UnixStream { /// Requires the `socket_timeout` feature. #[cfg(feature = "socket_timeout")] pub fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { - self.set_timeout(timeout, libc::SO_RCVTIMEO) + self.inner.set_timeout(timeout, libc::SO_RCVTIMEO) } /// Sets the write timeout for the socket. @@ -321,52 +394,7 @@ impl UnixStream { /// Requires the `socket_timeout` feature. #[cfg(feature = "socket_timeout")] pub fn set_write_timeout(&self, timeout: Option) -> io::Result<()> { - self.set_timeout(timeout, libc::SO_SNDTIMEO) - } - - #[cfg(feature = "socket_timeout")] - fn set_timeout(&self, dur: Option, kind: libc::c_int) -> io::Result<()> { - let timeout = match dur { - Some(dur) => { - if dur.secs() == 0 && dur.extra_nanos() == 0 { - return Err(io::Error::new(io::ErrorKind::InvalidInput, - "cannot set a 0 duration timeout")); - } - - let secs = if dur.secs() > libc::time_t::max_value() as u64 { - libc::time_t::max_value() - } else { - dur.secs() as libc::time_t - }; - let mut timeout = libc::timeval { - tv_sec: secs, - tv_usec: (dur.extra_nanos() / 1000) as libc::suseconds_t, - }; - if timeout.tv_sec == 0 && timeout.tv_usec == 0 { - timeout.tv_usec = 1; - } - timeout - } - None => { - libc::timeval { - tv_sec: 0, - tv_usec: 0, - } - } - }; - - let ret = unsafe { - libc::setsockopt(self.inner.0, - libc::SOL_SOCKET, - kind, - &timeout as *const _ as *const _, - mem::size_of::() as libc::socklen_t) - }; - if ret != 0 { - Err(io::Error::last_os_error()) - } else { - Ok(()) - } + self.inner.set_timeout(timeout, libc::SO_SNDTIMEO) } /// Returns the read timeout of this socket. @@ -374,7 +402,7 @@ impl UnixStream { /// Requires the `socket_timeout` feature. #[cfg(feature = "socket_timeout")] pub fn read_timeout(&self) -> io::Result> { - self.timeout(libc::SO_RCVTIMEO) + self.inner.timeout(libc::SO_RCVTIMEO) } /// Returns the write timeout of this socket. @@ -382,31 +410,7 @@ impl UnixStream { /// Requires the `socket_timeout` feature. #[cfg(feature = "socket_timeout")] pub fn write_timeout(&self) -> io::Result> { - self.timeout(libc::SO_SNDTIMEO) - } - - #[cfg(feature = "socket_timeout")] - fn timeout(&self, kind: libc::c_int) -> io::Result> { - let timeout = unsafe { - let mut timeout: libc::timeval = mem::zeroed(); - let mut size = mem::size_of::() as libc::socklen_t; - let ret = getsockopt(self.inner.0, - libc::SOL_SOCKET, - kind, - &mut timeout as *mut _ as *mut _, - &mut size as *mut _ as *mut _); - if ret != 0 { - return Err(io::Error::last_os_error()); - } - timeout - }; - - if timeout.tv_sec == 0 && timeout.tv_usec == 0 { - Ok(None) - } else { - Ok(Some(std::time::Duration::new(timeout.tv_sec as u64, - (timeout.tv_usec as u32) * 1000))) - } + self.inner.timeout(libc::SO_SNDTIMEO) } /// Shut down the read, write, or both halves of this connection. @@ -415,18 +419,7 @@ impl UnixStream { /// specified portions to immediately return with an appropriate value /// (see the documentation of `Shutdown`). pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { - let how = match how { - Shutdown::Read => libc::SHUT_RD, - Shutdown::Write => libc::SHUT_WR, - Shutdown::Both => libc::SHUT_RDWR, - }; - - let ret = unsafe { libc::shutdown(self.inner.0, how) }; - if ret != 0 { - Err(io::Error::last_os_error()) - } else { - Ok(()) - } + self.inner.shutdown(how) } } @@ -436,28 +429,34 @@ fn calc_len(buf: &[u8]) -> libc::size_t { impl io::Read for UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - let ret = unsafe { - libc::recv(self.inner.0, buf.as_mut_ptr() as *mut _, calc_len(buf), 0) - }; + io::Read::read(&mut &*self, buf) + } +} - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(ret as usize) +impl<'a> io::Read for &'a UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + unsafe { + cvt_s(libc::recv(self.inner.0, buf.as_mut_ptr() as *mut _, calc_len(buf), 0)) + .map(|r| r as usize) } } } impl io::Write for UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - let ret = unsafe { - libc::send(self.inner.0, buf.as_ptr() as *const _, calc_len(buf), 0) - }; + io::Write::write(&mut &*self, buf) + } - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(ret as usize) + fn flush(&mut self) -> io::Result<()> { + io::Write::flush(&mut &*self) + } +} + +impl<'a> io::Write for &'a UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + unsafe { + cvt_s(libc::send(self.inner.0, buf.as_ptr() as *const _, calc_len(buf), 0)) + .map(|r| r as usize) } } @@ -539,36 +538,23 @@ impl UnixListener { /// corresponding to a path on the filesystem. pub fn bind>(path: P) -> io::Result { unsafe { - let inner = try!(Inner::new()); + let inner = try!(Inner::new(libc::SOCK_STREAM)); let (addr, len) = try!(sockaddr_un(path)); - let ret = libc::bind(inner.0, &addr as *const _ as *const _, len); - if ret < 0 { - return Err(io::Error::last_os_error()); - } + try!(cvt(libc::bind(inner.0, &addr as *const _ as *const _, len))); + try!(cvt(libc::listen(inner.0, 128))); - let ret = libc::listen(inner.0, 128); - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(UnixListener { - inner: inner, - }) - } + Ok(UnixListener { + inner: inner, + }) } } /// Accepts a new incoming connection to this listener. pub fn accept(&self) -> io::Result { unsafe { - let ret = libc::accept(self.inner.0, 0 as *mut _, 0 as *mut _); - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(UnixStream { - inner: Inner(ret) - }) - } + cvt(libc::accept(self.inner.0, 0 as *mut _, 0 as *mut _)) + .map(|fd| UnixStream { inner: Inner(fd) }) } } @@ -585,7 +571,7 @@ impl UnixListener { /// Returns the socket address of the local half of this connection. pub fn local_addr(&self) -> io::Result { - SocketAddr::new(self.inner.0, libc::getsockname) + SocketAddr::new(|addr, len| unsafe { libc::getsockname(self.inner.0, addr, len) }) } /// Returns an iterator over incoming connections. @@ -643,6 +629,158 @@ impl<'a> Iterator for Incoming<'a> { } } +/// A Unix datagram socket. +/// +/// # Examples +/// +/// ```rust,no_run +/// use unix_socket::UnixDatagram; +/// +/// let socket = UnixDatagram::bind("/path/to/my/socket").unwrap(); +/// socket.send_to(b"hello world", "/path/to/other/socket").unwrap(); +/// let mut buf = [0; 100]; +/// let (count, address) = socket.recv_from(&mut buf).unwrap(); +/// println!("socket {:?} sent {:?}", address, &buf[..count]); +/// ``` +pub struct UnixDatagram { + inner: Inner, +} + +impl fmt::Debug for UnixDatagram { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + let mut builder = DebugStruct::new(fmt, "UnixDatagram") + .field("fd", &self.inner.0); + if let Ok(addr) = self.local_addr() { + builder = builder.field("local", &addr); + } + builder.finish() + } +} + +impl UnixDatagram { + /// Creates a Unix datagram socket from the given path. + pub fn bind>(path: P) -> io::Result { + unsafe { + let inner = try!(Inner::new(libc::SOCK_DGRAM)); + let (addr, len) = try!(sockaddr_un(path)); + + try!(cvt(libc::bind(inner.0, &addr as *const _ as *const _, len))); + + Ok(UnixDatagram { + inner: inner, + }) + } + } + + /// Returns the address of this socket. + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { libc::getsockname(self.inner.0, addr, len) }) + } + + /// Receives data from the socket. + /// + /// On success, returns the number of bytes read and the address from + /// whence the data came. + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + let mut count = 0; + let addr = try!(SocketAddr::new(|addr, len| { + unsafe { + count = libc::recvfrom(self.inner.0, + buf.as_mut_ptr() as *mut _, + calc_len(buf), + 0, + addr, + len); + if count > 0 { 1 } else if count == 0 { 0 } else { -1 } + } + })); + + Ok((count as usize, addr)) + } + + /// Sends data on the socket to the given address. + /// + /// On success, returns the number of bytes written. + pub fn send_to>(&self, buf: &[u8], path: P) -> io::Result { + unsafe { + let (addr, len) = try!(sockaddr_un(path)); + + let count = try!(cvt_s(libc::sendto(self.inner.0, + buf.as_ptr() as *const _, + calc_len(buf), + 0, + &addr as *const _ as *const _, + len))); + Ok(count as usize) + } + } + + /// Sets the read timeout for the socket. + /// + /// If the provided value is `None`, then `recv_from` calls will block + /// indefinitely. It is an error to pass the zero `Duration` to this + /// method. + /// + /// Requires the `socket_timeout` feature. + #[cfg(feature = "socket_timeout")] + pub fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { + self.inner.set_timeout(timeout, libc::SO_RCVTIMEO) + } + + /// Sets the write timeout for the socket. + /// + /// If the provided value is `None`, then `send_to` calls will block + /// indefinitely. It is an error to pass the zero `Duration` to this + /// method. + /// + /// Requires the `socket_timeout` feature. + #[cfg(feature = "socket_timeout")] + pub fn set_write_timeout(&self, timeout: Option) -> io::Result<()> { + self.inner.set_timeout(timeout, libc::SO_SNDTIMEO) + } + + /// Returns the read timeout of this socket. + /// + /// Requires the `socket_timeout` feature. + #[cfg(feature = "socket_timeout")] + pub fn read_timeout(&self) -> io::Result> { + self.inner.timeout(libc::SO_RCVTIMEO) + } + + /// Returns the write timeout of this socket. + /// + /// Requires the `socket_timeout` feature. + #[cfg(feature = "socket_timeout")] + pub fn write_timeout(&self) -> io::Result> { + self.inner.timeout(libc::SO_SNDTIMEO) + } + + /// Shut down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation of `Shutdown`). + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.inner.shutdown(how) + } +} + +impl AsRawFd for UnixDatagram { + fn as_raw_fd(&self) -> RawFd { + self.inner.0 + } +} + +#[cfg(feature = "from_raw_fd")] +impl std::os::unix::io::FromRawFd for UnixDatagram { + /// Requires the `from_raw_fd` feature. + unsafe fn from_raw_fd(fd: RawFd) -> UnixDatagram { + UnixDatagram { + inner: Inner(fd) + } + } +} + #[cfg(test)] mod test { extern crate tempdir; @@ -652,7 +790,7 @@ mod test { use std::io::prelude::*; use self::tempdir::TempDir; - use {UnixListener, UnixStream}; + use {UnixListener, UnixStream, UnixDatagram}; macro_rules! or_panic { ($e:expr) => { @@ -883,4 +1021,20 @@ mod test { assert!(wait > Duration::from_millis(400)); assert!(wait < Duration::from_millis(1600)); } + + #[test] + fn test_unix_datagram() { + let dir = or_panic!(TempDir::new("unix_socket")); + let path1 = dir.path().join("sock1"); + let path2 = dir.path().join("sock2"); + + let sock1 = or_panic!(UnixDatagram::bind(&path1)); + let sock2 = or_panic!(UnixDatagram::bind(&path2)); + + let msg = b"hello world"; + or_panic!(sock1.send_to(msg, &path2)); + let mut buf = [0; 11]; + or_panic!(sock2.recv_from(&mut buf)); + assert_eq!(msg, &buf[..]); + } }