diff --git a/src/iter.rs b/src/iter.rs index 4c5fa88..98c1d7b 100644 --- a/src/iter.rs +++ b/src/iter.rs @@ -1,5 +1,5 @@ -use core::convert::TryInto; use core::convert::TryFrom; +use core::convert::TryInto; #[allow(missing_docs)] pub struct Bytes<'a> { @@ -41,19 +41,27 @@ impl<'a> Bytes<'a> { } } - #[inline] - pub fn peek_ahead(&self, n: usize) -> Option<u8> { - // SAFETY: obtain a potentially OOB pointer that is later compared against the `self.end` - // pointer. - let ptr = self.cursor.wrapping_add(n); - if ptr < self.end { - // SAFETY: bounds checked pointer dereference is safe - Some(unsafe { *ptr }) + /// Peek at byte `n` ahead of cursor + /// + /// # Safety + /// + /// Caller must ensure that `n <= self.len()`, otherwise `self.cursor.add(n)` is UB. + /// That means there are at least `n-1` bytes between `self.cursor` and `self.end` + /// and `self.cursor.add(n)` is either `self.end` or points to a valid byte. + #[inline] + pub unsafe fn peek_ahead(&self, n: usize) -> Option<u8> { + debug_assert!(n <= self.len()); + // SAFETY: by preconditions + let p = unsafe { self.cursor.add(n) }; + if p < self.end { + // SAFETY: by preconditions, if this is not `self.end`, + // then it is safe to dereference + Some(unsafe { *p }) } else { None } } - + #[inline] pub fn peek_n<'b: 'a, U: TryFrom<&'a [u8]>>(&'b self, n: usize) -> Option<U> { // TODO: once we bump MSRC, use const generics to allow only [u8; N] reads @@ -65,7 +73,7 @@ impl<'a> Bytes<'a> { /// Advance by 1, equivalent to calling `advance(1)`. /// /// # Safety - /// + /// /// Caller must ensure that Bytes hasn't been advanced/bumped by more than [`Bytes::len()`]. #[inline] pub unsafe fn bump(&mut self) { @@ -75,7 +83,7 @@ impl<'a> Bytes<'a> { /// Advance cursor by `n` /// /// # Safety - /// + /// /// Caller must ensure that Bytes hasn't been advanced/bumped by more than [`Bytes::len()`]. #[inline] pub unsafe fn advance(&mut self, n: usize) { @@ -104,7 +112,7 @@ impl<'a> Bytes<'a> { // TODO: this is an anti-pattern, should be removed /// Deprecated. Do not use! /// # Safety - /// + /// /// Caller must ensure that `skip` is at most the number of advances (i.e., `bytes.advance(3)` /// implies a skip of at most 3). #[inline] @@ -114,21 +122,21 @@ impl<'a> Bytes<'a> { self.commit(); head } - + #[inline] pub fn commit(&mut self) { self.start = self.cursor } /// # Safety - /// + /// /// see [`Bytes::advance`] safety comment. #[inline] pub unsafe fn advance_and_commit(&mut self, n: usize) { self.advance(n); self.commit(); } - + #[inline] pub fn as_ptr(&self) -> *const u8 { self.cursor @@ -138,14 +146,14 @@ impl<'a> Bytes<'a> { pub fn start(&self) -> *const u8 { self.start } - + #[inline] pub fn end(&self) -> *const u8 { self.end } - + /// # Safety - /// + /// /// Must ensure invariant `bytes.start() <= ptr && ptr <= bytes.end()`. #[inline] pub unsafe fn set_cursor(&mut self, ptr: *const u8) { diff --git a/src/lib.rs b/src/lib.rs index 076c582..bb9b61d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -840,20 +840,23 @@ pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { const POST: [u8; 4] = *b"POST"; match bytes.peek_n::<[u8; 4]>(4) { Some(GET) => { - // SAFETY: matched the ASCII string and boundary checked + // SAFETY: we matched "GET " which has 4 bytes and is ASCII let method = unsafe { - bytes.advance(4); - let buf = bytes.slice_skip(1); - str::from_utf8_unchecked(buf) + bytes.advance(4); // advance cursor past "GET " + str::from_utf8_unchecked(bytes.slice_skip(1)) // "GET" without space }; Ok(Status::Complete(method)) } - Some(POST) if bytes.peek_ahead(4) == Some(b' ') => { - // SAFETY: matched the ASCII string and boundary checked + // SAFETY: + // If `bytes.peek_n...` returns a Some([u8; 4]), + // then we are assured that `bytes` contains at least 4 bytes. + // Thus `bytes.len() >= 4`, + // and it is safe to peek at byte 4 with `bytes.peek_ahead(4)`. + Some(POST) if unsafe { bytes.peek_ahead(4) } == Some(b' ') => { + // SAFETY: we matched "POST " which has 5 bytes let method = unsafe { - bytes.advance(5); - let buf = bytes.slice_skip(1); - str::from_utf8_unchecked(buf) + bytes.advance(5); // advance cursor past "POST " + str::from_utf8_unchecked(bytes.slice_skip(1)) // "POST" without space }; Ok(Status::Complete(method)) }