Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 120 additions & 9 deletions src/riscv/lib/src/pvm/outbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,25 @@ impl<M: AtomMode> Outbox<M> {
message: OutboxMessage,
current_level: u32,
) -> Result<(), OutboxError> {
let level_index = current_level as usize % self.levels.len();
let level_index = self.level_index(current_level);
self.levels[level_index].write_message(message, current_level)
}

fn level_index(&self, level: u32) -> usize {
level as usize % self.levels.len()
}
}

impl Outbox<Normal> {
/// Read outbox messages at the given level
///
/// Warning: The caller must ensure that `level` is within the outbox
/// validity window
#[cfg_attr(not(test), expect(dead_code, reason = "outbox not in use"))]
pub(crate) fn read_level(&self, level: u32) -> Box<[Box<[u8]>]> {
let level_index = self.level_index(level);
self.levels[level_index].read_level(level)
}
}

impl<'normal> Provable<'normal> for Outbox<Normal> {
Expand Down Expand Up @@ -197,13 +213,13 @@ impl<M: AtomMode> OutboxLevel<M> {
message: OutboxMessage,
current_level: u32,
) -> Result<(), OutboxError> {
let previous_level = self.level.read();
let last_written_level = self.level.read();
assert!(
current_level >= previous_level,
"current_level {current_level} must be gte to any level already stored in the outbox. Found {previous_level}"
current_level >= last_written_level,
"current_level {current_level} must be gte to any level already stored in the outbox. Found {last_written_level}"
);

if current_level > previous_level {
if current_level > last_written_level {
self.next_index.write(0);
self.level.write(current_level);
}
Expand All @@ -220,6 +236,27 @@ impl<M: AtomMode> OutboxLevel<M> {
}
}

impl OutboxLevel<Normal> {
fn read_level(&self, level: u32) -> Box<[Box<[u8]>]> {
let last_written_level = self.level.read();
debug_assert!(
level >= last_written_level,
"level {level} must be gte to the last written level for this outbox level slot. Found {last_written_level}"
);

let next_index = self.next_index.read() as usize;
if level != last_written_level || next_index == 0 {
// The outbox is empty for `level`
return Box::new([]);
}

self.messages[..next_index]
.iter()
.map(|msg| Box::from(msg.as_ref()))
.collect::<Box<[_]>>()
}
}

impl<'normal> Provable<'normal> for OutboxLevel<Normal> {
type Prover = OutboxLevel<Prove<'normal>>;

Expand Down Expand Up @@ -359,19 +396,39 @@ impl DerefMut for OutboxMessage {

#[cfg(test)]
mod tests {
use std::ops::Bound::*;
use std::ops::RangeBounds;
use std::ops::RangeInclusive;

use itertools::Itertools;
use proptest::prelude::*;

use super::*;

fn message_strategy(size_range: RangeInclusive<usize>) -> impl Strategy<Value = OutboxMessage> {
proptest::collection::vec(any::<u8>(), size_range)
fn safe_size_range(size_range: impl RangeBounds<usize>) -> RangeInclusive<usize> {
let start_bound = match size_range.start_bound() {
Included(s) => *s,
Excluded(s) => *s + 1,
Unbounded => 0,
};

match size_range.end_bound() {
Included(end) => start_bound..=MAX_OUTPUT_SIZE.min(*end),
Excluded(end) => start_bound..=MAX_OUTPUT_SIZE.min(end.saturating_sub(1)),
Unbounded => start_bound..=MAX_OUTPUT_SIZE,
}
}

fn message_strategy(
size_range: impl RangeBounds<usize>,
) -> impl Strategy<Value = OutboxMessage> {
let safe_range = safe_size_range(size_range);
proptest::collection::vec(any::<u8>(), safe_range)
.prop_map(|data| OutboxMessage(data.into_boxed_slice()))
}

fn messages_strategy(
size_range: RangeInclusive<usize>,
size_range: impl RangeBounds<usize>,
len: usize,
) -> impl Strategy<Value = Vec<OutboxMessage>> {
proptest::collection::vec(message_strategy(size_range), len)
Expand All @@ -380,7 +437,7 @@ mod tests {
#[test]
fn write_messages_with_varying_sizes() {
proptest!(|(
messages in messages_strategy(0..=MAX_OUTPUT_SIZE, MAX_LEVEL_SIZE),
messages in messages_strategy(0.., MAX_LEVEL_SIZE),
level in 0u32..TEST_OUTBOX_SIZE as u32
)| {
let mut outbox = Outbox::<Normal>::default();
Expand Down Expand Up @@ -448,4 +505,58 @@ mod tests {
assert!(matches!(res, Err(OutboxError::OutboxMessageTooLarge { size: s }) if s == size));
})
}

#[test]
fn read_level_after_write() {
proptest!(|(
messages in proptest::collection::vec(message_strategy(1..=MAX_OUTPUT_SIZE), 1..MAX_LEVEL_SIZE),
level in 0u32..1000
)| {
let mut outbox = Outbox::<Normal>::default();
for msg in &messages {
prop_assert!(outbox.write_message(msg.clone(), level).is_ok());
}

let read_messages = outbox.read_level(level);

prop_assert_eq!(read_messages.len(), messages.len());
for (i, msg) in messages.iter().enumerate() {
prop_assert_eq!(read_messages[i].as_ref(), msg.as_ref() as &[u8]);
}
});
}

#[test]
fn read_overwritten_slot_returns_new_level_data() {
proptest!(|(
messages1 in proptest::collection::vec(messages_strategy(0..=32, 50), TEST_OUTBOX_SIZE),
messages2 in proptest::collection::vec(messages_strategy(0..=16, 10), TEST_OUTBOX_SIZE)
)|{
let mut outbox = Outbox::<Normal>::default();
for (level, msgs) in messages1.iter().enumerate() {
for msg in msgs {
prop_assert!(outbox.write_message(msg.clone(), level as u32).is_ok());
}
}

for (offset, msgs) in messages2.iter().enumerate() {
let wrap_level = TEST_OUTBOX_SIZE + offset;
for msg in msgs {
prop_assert!(outbox.write_message(msg.clone(), wrap_level as u32).is_ok());
}
let read_messages = outbox.read_level(wrap_level as u32);
let expected_messages: Box<[Box<[u8]>]> = Box::from(messages2[offset].clone().into_iter().map(|m|m.0).collect_vec());
prop_assert_eq!(read_messages, expected_messages);
}
});
}

#[test]
fn read_fresh_outbox_is_empty() {
proptest!(|(level in 0u32..TEST_OUTBOX_SIZE as u32)| {
let outbox = Outbox::<Normal>::default();
let result = outbox.read_level(level);
prop_assert_eq!(result.len(), 0)
});
}
}
Loading