diff --git a/src/riscv/lib/src/pvm/node_pvm.rs b/src/riscv/lib/src/pvm/node_pvm.rs index 93cd8f410b..a24bf038af 100644 --- a/src/riscv/lib/src/pvm/node_pvm.rs +++ b/src/riscv/lib/src/pvm/node_pvm.rs @@ -203,6 +203,10 @@ impl> NodePvm { let proof_state = self.state.start_proof(); proof_state.produce_outbox_proof(output) } + + pub fn get_outbox_messages(&self, level: u32) -> Vec { + self.with_backend(|pvm| pvm.get_outbox_messages(level).unwrap_or_default()) + } } impl NodePvm { diff --git a/src/riscv/lib/src/pvm/outbox.rs b/src/riscv/lib/src/pvm/outbox.rs index 3d1240fbe1..41ade79ae6 100644 --- a/src/riscv/lib/src/pvm/outbox.rs +++ b/src/riscv/lib/src/pvm/outbox.rs @@ -185,7 +185,6 @@ impl Outbox { /// /// Warning: The caller must ensure that `level` is within the outbox /// validity window - #[cfg_attr(not(test), expect(dead_code, reason = "outbox not in use"))] pub(crate) fn read_level(&self, level: u32) -> Box<[Box<[u8]>]> { let level_index = self.level_index(level); self.levels[level_index].read_level(level) @@ -298,6 +297,32 @@ impl, M: Mode> Pvm { } } +impl> Pvm { + /// Retrieves the outbox messages for a given level. Returns None if the level is + /// not in the outbox + pub(crate) fn get_outbox_messages(&self, level: u32) -> Option> { + self.check_level_in_outbox(level).ok()?; + + let result = self + .outbox + .read_level(level) + .into_iter() + .enumerate() + .map(|(i, msg)| Output { + // Message length was checked when it was written. No need to + // check again + message: OutboxMessage(msg), + info: OutputInfo { + level, + index: i as u32, + }, + }) + .collect(); + + Some(result) + } +} + #[perfect_derive(Clone, PartialEq, Eq)] struct OutboxLevel { messages: Box<[Atom, M>]>, @@ -889,4 +914,98 @@ mod tests { } }); } + + #[test] + fn get_outbox_messages_returns_all_messages_at_level() { + proptest!(|( + messages in messages_strategy(0.., 5), + level in 0u32..1000 + )| { + type MC = M1M; + type PC = EmptyPageCache; + + let mut pvm = Pvm::::default(); + + pvm.level_is_set.write(true); + pvm.level.write(level); + + for message in &messages { + prop_assert!(pvm.outbox.write_message(message.clone(), level).is_ok()); + } + + let result = pvm.get_outbox_messages(level); + let outputs = result.unwrap(); + + prop_assert_eq!(outputs.len(), messages.len()); + for (i, (output, message)) in outputs.iter().zip(messages.iter()).enumerate() { + prop_assert_eq!(&*output.message, &**message); + prop_assert_eq!(output.info.level, level); + prop_assert_eq!(output.info.index, i as u32); + } + }); + } + + #[test] + fn get_outbox_messages_returns_none_for_uninitialised_pvm() { + proptest!(|(level in 0u32..1000)| { + type MC = M1M; + type PC = EmptyPageCache; + + let pvm = Pvm::::default(); + prop_assert!(pvm.get_outbox_messages(level).is_none()); + }); + } + + #[test] + fn get_outbox_messages_returns_none_for_future_level() { + proptest!(|( + messages in messages_strategy(0.., 5), + level in 0u32..1000 + )| { + type MC = M1M; + type PC = EmptyPageCache; + + let mut pvm = Pvm::::default(); + + pvm.level_is_set.write(true); + pvm.level.write(level); + + for message in &messages { + prop_assert!(pvm.outbox.write_message(message.clone(), level).is_ok()); + } + + prop_assert!(pvm.get_outbox_messages(level).is_some()); + + // Query a level beyond the current level + prop_assert!(pvm.get_outbox_messages(level + 1).is_none()); + }); + } + + #[test] + fn get_outbox_messages_returns_none_for_expired_level() { + proptest!(|( + messages in messages_strategy(0.., 5), + level in TEST_OUTBOX_SIZE as u32..1000 + )| { + type MC = M1M; + type PC = EmptyPageCache; + + let mut pvm = Pvm::::default(); + pvm.level_is_set.write(true); + pvm.level.write(level); + + for message in &messages { + prop_assert!(pvm.outbox.write_message(message.clone(), level).is_ok()); + } + + prop_assert!(pvm.get_outbox_messages(level).is_some()); + + // Advance the PVM level so write_level has expired + let current_level = level + TEST_OUTBOX_SIZE as u32 + 1; + pvm.level_is_set.write(true); + pvm.level.write(current_level); + + prop_assert!(pvm.get_outbox_messages(level).is_none()); + }); + } }