diff --git a/Cargo.toml b/Cargo.toml index 72ec6be..8278994 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,5 +11,9 @@ edition = "2018" [dependencies] bitflags = "1.2.0" -libc = "0.2.66" +nix = "0.16.0" uring-sys = "1.0.0-beta" + +[dev-dependencies] +semver = "0.9.0" + diff --git a/src/lib.rs b/src/lib.rs index 9a93fd2..24b31c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,10 +55,22 @@ use std::mem::MaybeUninit; use std::ptr::{self, NonNull}; use std::time::Duration; -pub use sqe::{SubmissionQueue, SubmissionQueueEvent, SubmissionFlags, FsyncFlags, PollMask}; +pub use sqe::{SubmissionQueue, SubmissionQueueEvent, SubmissionFlags, FsyncFlags, SockAddrStorage}; pub use cqe::{CompletionQueue, CompletionQueueEvent}; pub use registrar::Registrar; +pub use nix::poll::PollFlags; +pub use nix::sys::socket::{ + SockAddr, + SockFlag, + InetAddr, + UnixAddr, + NetlinkAddr, + AlgAddr, + LinkAddr, + VsockAddr +}; + bitflags::bitflags! { /// `IoUring` initialization flags for advanced use cases. /// diff --git a/src/sqe.rs b/src/sqe.rs index 1d8c202..1f53267 100644 --- a/src/sqe.rs +++ b/src/sqe.rs @@ -6,6 +6,7 @@ use std::marker::PhantomData; use std::time::Duration; use super::IoUring; +use super::{PollFlags, SockAddr, SockFlag}; /// The queue of pending IO events. /// @@ -274,8 +275,18 @@ impl<'a> SubmissionQueueEvent<'a> { } #[inline] - pub unsafe fn prep_poll_add(&mut self, fd: RawFd, poll_mask: PollMask) { - uring_sys::io_uring_prep_poll_add(self.sqe, fd, poll_mask.bits()) + pub unsafe fn prep_timeout_remove(&mut self, user_data: u64) { + uring_sys::io_uring_prep_timeout_remove(self.sqe, user_data as _, 0); + } + + #[inline] + pub unsafe fn prep_link_timeout(&mut self, ts: &uring_sys::__kernel_timespec) { + uring_sys::io_uring_prep_link_timeout(self.sqe, ts as *const _ as *mut _, 0); + } + + #[inline] + pub unsafe fn prep_poll_add(&mut self, fd: RawFd, poll_flags: PollFlags) { + uring_sys::io_uring_prep_poll_add(self.sqe, fd, poll_flags.bits()) } #[inline] @@ -283,6 +294,21 @@ impl<'a> SubmissionQueueEvent<'a> { uring_sys::io_uring_prep_poll_remove(self.sqe, user_data as _) } + #[inline] + pub unsafe fn prep_connect(&mut self, fd: RawFd, socket_addr: &SockAddr) { + let (addr, len) = socket_addr.as_ffi_pair(); + uring_sys::io_uring_prep_connect(self.sqe, fd, addr as *const _ as *mut _, len); + } + + #[inline] + pub unsafe fn prep_accept(&mut self, fd: RawFd, accept: Option<&mut SockAddrStorage>, flags: SockFlag) { + let (addr, len) = match accept { + Some(accept) => (accept.storage.as_mut_ptr() as *mut _, &mut accept.len as *mut _ as *mut _), + None => (std::ptr::null_mut(), std::ptr::null_mut()) + }; + uring_sys::io_uring_prep_accept(self.sqe, fd, addr, len, flags.bits()) + } + /// Prepare a no-op event. /// ``` /// # use iou::{IoUring, SubmissionFlags}; @@ -358,6 +384,33 @@ impl<'a> SubmissionQueueEvent<'a> { unsafe impl<'a> Send for SubmissionQueueEvent<'a> { } unsafe impl<'a> Sync for SubmissionQueueEvent<'a> { } +pub struct SockAddrStorage { + storage: mem::MaybeUninit, + len: usize, +} + +impl SockAddrStorage { + pub fn uninit() -> Self { + let storage = mem::MaybeUninit::uninit(); + let len = mem::size_of::(); + SockAddrStorage { + storage, + len + } + } + + pub unsafe fn as_socket_addr(&self) -> io::Result { + let storage = &*self.storage.as_ptr(); + nix::sys::socket::sockaddr_storage_to_addr(storage, self.len).map_err(|e| { + let err_no = e.as_errno(); + match err_no { + Some(err_no) => io::Error::from_raw_os_error(err_no as _), + None => io::Error::new(io::ErrorKind::Other, "Unknown error") + } + }) + } +} + bitflags::bitflags! { /// [`SubmissionQueueEvent`](SubmissionQueueEvent) configuration flags. /// @@ -387,16 +440,3 @@ bitflags::bitflags! { const TIMEOUT_ABS = 1 << 0; } } - -bitflags::bitflags! { - pub struct PollMask: libc::c_short { - const POLLIN = libc::POLLIN; - const POLLPRI = libc::POLLPRI; - const POLLOUT = libc::POLLOUT; - const POLLERR = libc::POLLERR; - const POLLHUP = libc::POLLHUP; - const POLLNVAL = libc::POLLNVAL; - const POLLRDNORM = libc::POLLRDNORM; - const POLLRDBAND = libc::POLLRDBAND; - } -} diff --git a/tests/accept.rs b/tests/accept.rs new file mode 100644 index 0000000..af3e533 --- /dev/null +++ b/tests/accept.rs @@ -0,0 +1,68 @@ +use nix::sys::socket::InetAddr; +use std::{ + io::{self, Read, Write}, + net::{TcpListener, TcpStream}, + os::unix::io::{AsRawFd, FromRawFd}, +}; +use iou::SockAddr; + +const MESSAGE: &'static [u8] = b"Hello World"; + +#[test] +#[ignore] // kernel 5.5 needed for accept +fn accept() -> io::Result<()> { + let mut ring = iou::IoUring::new(1)?; + + let listener = TcpListener::bind(("0.0.0.0", 0))?; + listener.set_nonblocking(true)?; + + let mut stream = TcpStream::connect(listener.local_addr()?)?; + stream.write_all(MESSAGE)?; + + let fd = listener.as_raw_fd(); + let mut sq = ring.sq(); + let mut sqe = sq.next_sqe().expect("failed to get sqe"); + unsafe { + sqe.prep_accept(fd, None, iou::SockFlag::empty()); + sq.submit()?; + } + let cqe = ring.wait_for_cqe()?; + let accept_fd = cqe.result()?; + let mut accept_buf = [0; MESSAGE.len()]; + let mut stream = unsafe { TcpStream::from_raw_fd(accept_fd as _) }; + stream.read_exact(&mut accept_buf)?; + assert_eq!(accept_buf, MESSAGE); + Ok(()) +} + +#[test] +#[ignore] // kernel 5.5 needed for accept +fn accept_with_params() -> io::Result<()> { + let mut ring = iou::IoUring::new(1)?; + + let listener = TcpListener::bind(("0.0.0.0", 0))?; + listener.set_nonblocking(true)?; + + let mut connection_stream = TcpStream::connect(listener.local_addr()?)?; + connection_stream.write_all(MESSAGE)?; + + let fd = listener.as_raw_fd(); + let mut sq = ring.sq(); + let mut sqe = sq.next_sqe().expect("failed to get sqe"); + let mut accept_params = iou::SockAddrStorage::uninit(); + unsafe { + sqe.prep_accept(fd, Some(&mut accept_params), iou::SockFlag::empty()); + sq.submit()?; + } + let cqe = ring.wait_for_cqe()?; + let accept_fd = cqe.result()?; + let mut accept_buf = [0; MESSAGE.len()]; + let mut accepted_stream = unsafe { TcpStream::from_raw_fd(accept_fd as _) }; + accepted_stream.read_exact(&mut accept_buf)?; + assert_eq!(accept_buf, MESSAGE); + + let addr = unsafe { accept_params.as_socket_addr()? }; + let connection_addr = SockAddr::Inet(InetAddr::from_std(&connection_stream.local_addr()?)); + assert_eq!(addr, connection_addr); + Ok(()) +} \ No newline at end of file diff --git a/tests/connect.rs b/tests/connect.rs new file mode 100644 index 0000000..522c800 --- /dev/null +++ b/tests/connect.rs @@ -0,0 +1,30 @@ +use nix::sys::socket::{AddressFamily, SockProtocol, SockType, InetAddr, SockFlag}; +use std::{io, net::TcpListener}; + +#[test] +#[ignore] // kernel 5.5 needed for connect +fn connect() -> io::Result<()> { + let listener = TcpListener::bind(("0.0.0.0", 0))?; + listener.set_nonblocking(true)?; + let listener_addr = iou::SockAddr::new_inet(InetAddr::from_std(&listener.local_addr()?)); + + let socket = nix::sys::socket::socket( + AddressFamily::Inet, + SockType::Stream, + SockFlag::SOCK_NONBLOCK, + SockProtocol::Tcp, + ) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "failed to create socket"))?; + + let mut ring = iou::IoUring::new(1)?; + let mut sqe = ring.next_sqe().expect("failed to get sqe"); + unsafe { + sqe.prep_connect(socket, &listener_addr); + sqe.set_user_data(42); + ring.submit_sqes()?; + } + let cqe = ring.wait_for_cqe()?; + let _res = cqe.result()?; + assert_eq!(cqe.user_data(), 42); + Ok(()) +} diff --git a/tests/poll.rs b/tests/poll.rs index 9df5303..2d53893 100644 --- a/tests/poll.rs +++ b/tests/poll.rs @@ -1,79 +1,66 @@ -#![feature(test)] -extern crate libc; -extern crate test; - -use std::{io, os::unix::io::RawFd}; - -pub fn pipe() -> io::Result<(RawFd, RawFd)> { - unsafe { - let mut fds = core::mem::MaybeUninit::<[libc::c_int; 2]>::uninit(); - - let res = libc::pipe(fds.as_mut_ptr() as *mut libc::c_int); - - if res < 0 { - Err(io::Error::from_raw_os_error(-res)) - } else { - Ok((fds.assume_init()[0], fds.assume_init()[1])) - } - } -} +use std::{ + io::{self, Read, Write}, + os::unix::{io::AsRawFd, net}, +}; +const MESSAGE: &'static [u8] = b"Hello World"; #[test] fn test_poll_add() -> io::Result<()> { let mut ring = iou::IoUring::new(2)?; - let (read, write) = pipe()?; - + let (mut read, mut write) = net::UnixStream::pair()?; unsafe { - let mut sqe = ring.next_sqe().expect("no sqe"); - sqe.prep_poll_add(read, iou::PollMask::POLLIN); + let mut sqe = ring.next_sqe().expect("failed to get sqe"); + sqe.prep_poll_add(read.as_raw_fd(), iou::PollFlags::POLLIN); sqe.set_user_data(0xDEADBEEF); ring.submit_sqes()?; } - let res = unsafe { - let buf = b"hello"; - libc::write( - write, - buf.as_ptr() as *const libc::c_void, - buf.len() as libc::size_t, - ) - }; - - if res < 0 { - return Err(io::Error::from_raw_os_error(-res as _)); - } + write.write(MESSAGE)?; let cqe = ring.wait_for_cqe()?; assert_eq!(cqe.user_data(), 0xDEADBEEF); - let mask = unsafe { iou::PollMask::from_bits_unchecked(cqe.result()? as _) }; - assert!(mask.contains(iou::PollMask::POLLIN)); - unsafe { - libc::close(write); - libc::close(read); - } + let mask = unsafe { iou::PollFlags::from_bits_unchecked(cqe.result()? as _) }; + assert!(mask.contains(iou::PollFlags::POLLIN)); + let mut buf = [0; MESSAGE.len()]; + read.read(&mut buf)?; + assert_eq!(buf, MESSAGE); Ok(()) } #[test] fn test_poll_remove() -> io::Result<()> { let mut ring = iou::IoUring::new(2)?; - let (read, write) = pipe()?; - + let (read, _write) = net::UnixStream::pair()?; + let uname = nix::sys::utsname::uname(); + let version = semver::Version::parse(uname.release()); unsafe { - let mut sqe = ring.next_sqe().expect("no sqe"); - sqe.prep_poll_add(read, iou::PollMask::POLLIN); + let mut sqe = ring.next_sqe().expect("failed to get sqe"); + sqe.prep_poll_add(read.as_raw_fd(), iou::PollFlags::POLLIN); sqe.set_user_data(0xDEADBEEF); ring.submit_sqes()?; - let mut sqe = ring.next_sqe().expect("no sqe"); + let mut sqe = ring.next_sqe().expect("failed to get sqe"); sqe.prep_poll_remove(0xDEADBEEF); + sqe.set_user_data(42); ring.submit_sqes()?; for _ in 0..2 { let cqe = ring.wait_for_cqe()?; - let _ = cqe.result()?; + let user_data = cqe.user_data(); + if version < semver::Version::parse("5.5.0-0") { + let _ = cqe.result()?; + } else if user_data == 0xDEADBEEF { + let err = cqe + .result() + .expect_err("on kernels >=5.5 error is expected"); + let err_no = nix::errno::Errno::from_i32( + err.raw_os_error() + .expect("on kernels >=5.5 os_error is expected"), + ); + assert_eq!(err_no, nix::errno::Errno::ECANCELED); + } else { + let _ = cqe.result()?; + } } - libc::close(write); - libc::close(read); Ok(()) } }