Skip to content

Commit

Permalink
refactor: reduce ease of soundness mistake of peek_ahead in iter.rs (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
hkBst authored Jan 10, 2025
1 parent f1cbffc commit eb91549
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 28 deletions.
46 changes: 27 additions & 19 deletions src/iter.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::convert::TryInto;
use core::convert::TryFrom;
use core::convert::TryInto;

#[allow(missing_docs)]
pub struct Bytes<'a> {
Expand Down Expand Up @@ -41,19 +41,27 @@ impl<'a> Bytes<'a> {
}
}

#[inline]
pub fn peek_ahead(&self, n: usize) -> Option<u8> {
// SAFETY: obtain a potentially OOB pointer that is later compared against the `self.end`
// pointer.
let ptr = self.cursor.wrapping_add(n);
if ptr < self.end {
// SAFETY: bounds checked pointer dereference is safe
Some(unsafe { *ptr })
/// Peek at byte `n` ahead of cursor
///
/// # Safety
///
/// Caller must ensure that `n <= self.len()`, otherwise `self.cursor.add(n)` is UB.
/// That means at least `n` bytes lie in `self.cursor..self.end` (end exclusive),
/// so `self.cursor.add(n)` is either `self.end` or points to a valid byte.
#[inline]
pub unsafe fn peek_ahead(&self, n: usize) -> Option<u8> {
debug_assert!(n <= self.len());
// SAFETY: by preconditions
let p = unsafe { self.cursor.add(n) };

Check warning on line 55 in src/iter.rs

View workflow job for this annotation

GitHub Actions / msrv (x64)

unnecessary `unsafe` block
if p < self.end {
// SAFETY: by preconditions, if this is not `self.end`,
// then it is safe to dereference
Some(unsafe { *p })

Check warning on line 59 in src/iter.rs

View workflow job for this annotation

GitHub Actions / msrv (x64)

unnecessary `unsafe` block

Check warning on line 59 in src/iter.rs

View workflow job for this annotation

GitHub Actions / msrv (aarch64)

unnecessary `unsafe` block
} else {
None
}
}

#[inline]
pub fn peek_n<'b: 'a, U: TryFrom<&'a [u8]>>(&'b self, n: usize) -> Option<U> {
// TODO: once we bump MSRV, use const generics to allow only [u8; N] reads
Expand All @@ -65,7 +73,7 @@ impl<'a> Bytes<'a> {
/// Advance by 1, equivalent to calling `advance(1)`.
///
/// # Safety
///
///
/// Caller must ensure that Bytes hasn't been advanced/bumped by more than [`Bytes::len()`].
#[inline]
pub unsafe fn bump(&mut self) {
Expand All @@ -75,7 +83,7 @@ impl<'a> Bytes<'a> {
/// Advance cursor by `n`
///
/// # Safety
///
///
/// Caller must ensure that Bytes hasn't been advanced/bumped by more than [`Bytes::len()`].
#[inline]
pub unsafe fn advance(&mut self, n: usize) {
Expand Down Expand Up @@ -104,7 +112,7 @@ impl<'a> Bytes<'a> {
// TODO: this is an anti-pattern, should be removed
/// Deprecated. Do not use!
/// # Safety
///
///
/// Caller must ensure that `skip` is at most the number of advances (i.e., `bytes.advance(3)`
/// implies a skip of at most 3).
#[inline]
Expand All @@ -114,21 +122,21 @@ impl<'a> Bytes<'a> {
self.commit();
head
}

/// Sets `self.start` to the current value of `self.cursor`.
#[inline]
pub fn commit(&mut self) {
self.start = self.cursor
}

/// Advances the cursor by `n` via [`Bytes::advance`], then calls [`Bytes::commit`].
///
/// # Safety
///
/// See the safety comment on [`Bytes::advance`].
#[inline]
pub unsafe fn advance_and_commit(&mut self, n: usize) {
self.advance(n);
self.commit();
}

#[inline]
pub fn as_ptr(&self) -> *const u8 {
self.cursor
Expand All @@ -138,14 +146,14 @@ impl<'a> Bytes<'a> {
pub fn start(&self) -> *const u8 {
self.start
}

/// Returns the raw `end` pointer stored in this `Bytes`.
#[inline]
pub fn end(&self) -> *const u8 {
self.end
}

/// # Safety
///
///
/// Must ensure invariant `bytes.start() <= ptr && ptr <= bytes.end()`.
#[inline]
pub unsafe fn set_cursor(&mut self, ptr: *const u8) {
Expand Down
21 changes: 12 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,20 +840,23 @@ pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> {
const POST: [u8; 4] = *b"POST";
match bytes.peek_n::<[u8; 4]>(4) {
Some(GET) => {
// SAFETY: matched the ASCII string and boundary checked
// SAFETY: we matched "GET " which has 4 bytes and is ASCII
let method = unsafe {
bytes.advance(4);
let buf = bytes.slice_skip(1);
str::from_utf8_unchecked(buf)
bytes.advance(4); // advance cursor past "GET "
str::from_utf8_unchecked(bytes.slice_skip(1)) // "GET" without space
};
Ok(Status::Complete(method))
}
Some(POST) if bytes.peek_ahead(4) == Some(b' ') => {
// SAFETY: matched the ASCII string and boundary checked
// SAFETY:
// If `bytes.peek_n...` returns a Some([u8; 4]),
// then we are assured that `bytes` contains at least 4 bytes.
// Thus `bytes.len() >= 4`,
// and it is safe to peek at byte 4 with `bytes.peek_ahead(4)`.
Some(POST) if unsafe { bytes.peek_ahead(4) } == Some(b' ') => {
// SAFETY: we matched "POST " which has 5 bytes
let method = unsafe {
bytes.advance(5);
let buf = bytes.slice_skip(1);
str::from_utf8_unchecked(buf)
bytes.advance(5); // advance cursor past "POST "
str::from_utf8_unchecked(bytes.slice_skip(1)) // "POST" without space
};
Ok(Status::Complete(method))
}
Expand Down

0 comments on commit eb91549

Please sign in to comment.