Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,9 @@ edition = "2018"

[dependencies]
bitflags = "1.2.0"
libc = "0.2.66"
nix = "0.16.0"
uring-sys = "1.0.0-beta"

[dev-dependencies]
semver = "0.9.0"

14 changes: 13 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,22 @@ use std::mem::MaybeUninit;
use std::ptr::{self, NonNull};
use std::time::Duration;

pub use sqe::{SubmissionQueue, SubmissionQueueEvent, SubmissionFlags, FsyncFlags, PollMask};
pub use sqe::{SubmissionQueue, SubmissionQueueEvent, SubmissionFlags, FsyncFlags, SockAddrStorage};
pub use cqe::{CompletionQueue, CompletionQueueEvent};
pub use registrar::Registrar;

pub use nix::poll::PollFlags;
pub use nix::sys::socket::{
SockAddr,
SockFlag,
InetAddr,
UnixAddr,
NetlinkAddr,
AlgAddr,
LinkAddr,
VsockAddr
};

bitflags::bitflags! {
/// `IoUring` initialization flags for advanced use cases.
///
Expand Down
70 changes: 55 additions & 15 deletions src/sqe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::marker::PhantomData;
use std::time::Duration;

use super::IoUring;
use super::{PollFlags, SockAddr, SockFlag};

/// The queue of pending IO events.
///
Expand Down Expand Up @@ -274,15 +275,40 @@ impl<'a> SubmissionQueueEvent<'a> {
}

#[inline]
pub unsafe fn prep_poll_add(&mut self, fd: RawFd, poll_mask: PollMask) {
uring_sys::io_uring_prep_poll_add(self.sqe, fd, poll_mask.bits())
pub unsafe fn prep_timeout_remove(&mut self, user_data: u64) {
uring_sys::io_uring_prep_timeout_remove(self.sqe, user_data as _, 0);
}

#[inline]
pub unsafe fn prep_link_timeout(&mut self, ts: &uring_sys::__kernel_timespec) {
uring_sys::io_uring_prep_link_timeout(self.sqe, ts as *const _ as *mut _, 0);
}

#[inline]
pub unsafe fn prep_poll_add(&mut self, fd: RawFd, poll_flags: PollFlags) {
uring_sys::io_uring_prep_poll_add(self.sqe, fd, poll_flags.bits())
}

#[inline]
pub unsafe fn prep_poll_remove(&mut self, user_data: u64) {
uring_sys::io_uring_prep_poll_remove(self.sqe, user_data as _)
}

#[inline]
pub unsafe fn prep_connect(&mut self, fd: RawFd, socket_addr: &SockAddr) {
let (addr, len) = socket_addr.as_ffi_pair();
uring_sys::io_uring_prep_connect(self.sqe, fd, addr as *const _ as *mut _, len);
}

#[inline]
pub unsafe fn prep_accept(&mut self, fd: RawFd, accept: Option<&mut SockAddrStorage>, flags: SockFlag) {
let (addr, len) = match accept {
Some(accept) => (accept.storage.as_mut_ptr() as *mut _, &mut accept.len as *mut _ as *mut _),
None => (std::ptr::null_mut(), std::ptr::null_mut())
};
uring_sys::io_uring_prep_accept(self.sqe, fd, addr, len, flags.bits())
}

/// Prepare a no-op event.
/// ```
/// # use iou::{IoUring, SubmissionFlags};
Expand Down Expand Up @@ -358,6 +384,33 @@ impl<'a> SubmissionQueueEvent<'a> {
unsafe impl<'a> Send for SubmissionQueueEvent<'a> { }
unsafe impl<'a> Sync for SubmissionQueueEvent<'a> { }

pub struct SockAddrStorage {
storage: mem::MaybeUninit<nix::sys::socket::sockaddr_storage>,
len: usize,
}

impl SockAddrStorage {
pub fn uninit() -> Self {
let storage = mem::MaybeUninit::uninit();
let len = mem::size_of::<nix::sys::socket::sockaddr_storage>();
SockAddrStorage {
storage,
len
}
}

pub unsafe fn as_socket_addr(&self) -> io::Result<SockAddr> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to make sure I have a correct understanding of this API: this function is safe to call after as the prep_accept SQE you passed this to has completed, and calling it before it has completed would be incorrect. Is that right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely right.

let storage = &*self.storage.as_ptr();
nix::sys::socket::sockaddr_storage_to_addr(storage, self.len).map_err(|e| {
let err_no = e.as_errno();
match err_no {
Some(err_no) => io::Error::from_raw_os_error(err_no as _),
None => io::Error::new(io::ErrorKind::Other, "Unknown error")
}
})
}
}

bitflags::bitflags! {
/// [`SubmissionQueueEvent`](SubmissionQueueEvent) configuration flags.
///
Expand Down Expand Up @@ -387,16 +440,3 @@ bitflags::bitflags! {
const TIMEOUT_ABS = 1 << 0;
}
}

bitflags::bitflags! {
pub struct PollMask: libc::c_short {
const POLLIN = libc::POLLIN;
const POLLPRI = libc::POLLPRI;
const POLLOUT = libc::POLLOUT;
const POLLERR = libc::POLLERR;
const POLLHUP = libc::POLLHUP;
const POLLNVAL = libc::POLLNVAL;
const POLLRDNORM = libc::POLLRDNORM;
const POLLRDBAND = libc::POLLRDBAND;
}
}
68 changes: 68 additions & 0 deletions tests/accept.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use nix::sys::socket::InetAddr;
use std::{
io::{self, Read, Write},
net::{TcpListener, TcpStream},
os::unix::io::{AsRawFd, FromRawFd},
};
use iou::SockAddr;

const MESSAGE: &'static [u8] = b"Hello World";

#[test]
#[ignore] // kernel 5.5 needed for accept
fn accept() -> io::Result<()> {
let mut ring = iou::IoUring::new(1)?;

let listener = TcpListener::bind(("0.0.0.0", 0))?;
listener.set_nonblocking(true)?;

let mut stream = TcpStream::connect(listener.local_addr()?)?;
stream.write_all(MESSAGE)?;

let fd = listener.as_raw_fd();
let mut sq = ring.sq();
let mut sqe = sq.next_sqe().expect("failed to get sqe");
unsafe {
sqe.prep_accept(fd, None, iou::SockFlag::empty());
sq.submit()?;
}
let cqe = ring.wait_for_cqe()?;
let accept_fd = cqe.result()?;
let mut accept_buf = [0; MESSAGE.len()];
let mut stream = unsafe { TcpStream::from_raw_fd(accept_fd as _) };
stream.read_exact(&mut accept_buf)?;
assert_eq!(accept_buf, MESSAGE);
Ok(())
}

#[test]
#[ignore] // kernel 5.5 needed for accept
fn accept_with_params() -> io::Result<()> {
let mut ring = iou::IoUring::new(1)?;

let listener = TcpListener::bind(("0.0.0.0", 0))?;
listener.set_nonblocking(true)?;

let mut connection_stream = TcpStream::connect(listener.local_addr()?)?;
connection_stream.write_all(MESSAGE)?;

let fd = listener.as_raw_fd();
let mut sq = ring.sq();
let mut sqe = sq.next_sqe().expect("failed to get sqe");
let mut accept_params = iou::SockAddrStorage::uninit();
unsafe {
sqe.prep_accept(fd, Some(&mut accept_params), iou::SockFlag::empty());
sq.submit()?;
}
let cqe = ring.wait_for_cqe()?;
let accept_fd = cqe.result()?;
let mut accept_buf = [0; MESSAGE.len()];
let mut accepted_stream = unsafe { TcpStream::from_raw_fd(accept_fd as _) };
accepted_stream.read_exact(&mut accept_buf)?;
assert_eq!(accept_buf, MESSAGE);

let addr = unsafe { accept_params.as_socket_addr()? };
let connection_addr = SockAddr::Inet(InetAddr::from_std(&connection_stream.local_addr()?));
assert_eq!(addr, connection_addr);
Ok(())
}
30 changes: 30 additions & 0 deletions tests/connect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use nix::sys::socket::{AddressFamily, SockProtocol, SockType, InetAddr, SockFlag};
use std::{io, net::TcpListener};

#[test]
#[ignore] // kernel 5.5 needed for connect
fn connect() -> io::Result<()> {
let listener = TcpListener::bind(("0.0.0.0", 0))?;
listener.set_nonblocking(true)?;
let listener_addr = iou::SockAddr::new_inet(InetAddr::from_std(&listener.local_addr()?));

let socket = nix::sys::socket::socket(
AddressFamily::Inet,
SockType::Stream,
SockFlag::SOCK_NONBLOCK,
SockProtocol::Tcp,
)
.map_err(|_| io::Error::new(io::ErrorKind::Other, "failed to create socket"))?;

let mut ring = iou::IoUring::new(1)?;
let mut sqe = ring.next_sqe().expect("failed to get sqe");
unsafe {
sqe.prep_connect(socket, &listener_addr);
sqe.set_user_data(42);
ring.submit_sqes()?;
}
let cqe = ring.wait_for_cqe()?;
let _res = cqe.result()?;
assert_eq!(cqe.user_data(), 42);
Ok(())
}
85 changes: 36 additions & 49 deletions tests/poll.rs
Original file line number Diff line number Diff line change
@@ -1,79 +1,66 @@
#![feature(test)]
extern crate libc;
extern crate test;

use std::{io, os::unix::io::RawFd};

pub fn pipe() -> io::Result<(RawFd, RawFd)> {
unsafe {
let mut fds = core::mem::MaybeUninit::<[libc::c_int; 2]>::uninit();

let res = libc::pipe(fds.as_mut_ptr() as *mut libc::c_int);

if res < 0 {
Err(io::Error::from_raw_os_error(-res))
} else {
Ok((fds.assume_init()[0], fds.assume_init()[1]))
}
}
}
use std::{
io::{self, Read, Write},
os::unix::{io::AsRawFd, net},
};
const MESSAGE: &'static [u8] = b"Hello World";

#[test]
fn test_poll_add() -> io::Result<()> {
let mut ring = iou::IoUring::new(2)?;
let (read, write) = pipe()?;

let (mut read, mut write) = net::UnixStream::pair()?;
unsafe {
let mut sqe = ring.next_sqe().expect("no sqe");
sqe.prep_poll_add(read, iou::PollMask::POLLIN);
let mut sqe = ring.next_sqe().expect("failed to get sqe");
sqe.prep_poll_add(read.as_raw_fd(), iou::PollFlags::POLLIN);
sqe.set_user_data(0xDEADBEEF);
ring.submit_sqes()?;
}

let res = unsafe {
let buf = b"hello";
libc::write(
write,
buf.as_ptr() as *const libc::c_void,
buf.len() as libc::size_t,
)
};

if res < 0 {
return Err(io::Error::from_raw_os_error(-res as _));
}
write.write(MESSAGE)?;

let cqe = ring.wait_for_cqe()?;
assert_eq!(cqe.user_data(), 0xDEADBEEF);
let mask = unsafe { iou::PollMask::from_bits_unchecked(cqe.result()? as _) };
assert!(mask.contains(iou::PollMask::POLLIN));
unsafe {
libc::close(write);
libc::close(read);
}
let mask = unsafe { iou::PollFlags::from_bits_unchecked(cqe.result()? as _) };
assert!(mask.contains(iou::PollFlags::POLLIN));
let mut buf = [0; MESSAGE.len()];
read.read(&mut buf)?;
assert_eq!(buf, MESSAGE);
Ok(())
}

#[test]
fn test_poll_remove() -> io::Result<()> {
let mut ring = iou::IoUring::new(2)?;
let (read, write) = pipe()?;

let (read, _write) = net::UnixStream::pair()?;
let uname = nix::sys::utsname::uname();
let version = semver::Version::parse(uname.release());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good way to solve running tests only under certain versions of Linux, a problem also discussed in #6 and #17. After we merge this, it would be good to track developing a lightweight API for running tests differently before/after a certain version string.

unsafe {
let mut sqe = ring.next_sqe().expect("no sqe");
sqe.prep_poll_add(read, iou::PollMask::POLLIN);
let mut sqe = ring.next_sqe().expect("failed to get sqe");
sqe.prep_poll_add(read.as_raw_fd(), iou::PollFlags::POLLIN);
sqe.set_user_data(0xDEADBEEF);
ring.submit_sqes()?;

let mut sqe = ring.next_sqe().expect("no sqe");
let mut sqe = ring.next_sqe().expect("failed to get sqe");
sqe.prep_poll_remove(0xDEADBEEF);
sqe.set_user_data(42);
ring.submit_sqes()?;
for _ in 0..2 {
let cqe = ring.wait_for_cqe()?;
let _ = cqe.result()?;
let user_data = cqe.user_data();
if version < semver::Version::parse("5.5.0-0") {
let _ = cqe.result()?;
} else if user_data == 0xDEADBEEF {
let err = cqe
.result()
.expect_err("on kernels >=5.5 error is expected");
let err_no = nix::errno::Errno::from_i32(
err.raw_os_error()
.expect("on kernels >=5.5 os_error is expected"),
);
assert_eq!(err_no, nix::errno::Errno::ECANCELED);
} else {
let _ = cqe.result()?;
}
}
libc::close(write);
libc::close(read);
Ok(())
}
}