diff --git a/crates/vm/src/system/cuda/memory.rs b/crates/vm/src/system/cuda/memory.rs index 3e1ca58000..51d7b3677e 100644 --- a/crates/vm/src/system/cuda/memory.rs +++ b/crates/vm/src/system/cuda/memory.rs @@ -203,11 +203,12 @@ impl MemoryInventoryGPU { } mem.tracing_info("merkle update"); persistent.merkle_tree.finalize(); - Some(persistent.merkle_tree.update_with_touched_blocks( + let merkle_tree_ctx = persistent.merkle_tree.update_with_touched_blocks( unpadded_merkle_height, &d_touched_memory, empty, - )) + ); + Some(merkle_tree_ctx) } TouchedMemory::Volatile(partition) => { assert!(self.persistent.is_none(), "TouchedMemory enum mismatch"); @@ -234,12 +235,8 @@ impl MemoryInventoryGPU { impl Drop for PersistentMemoryInventoryGPU { fn drop(&mut self) { - // Force synchronize all streams in merkle tree before dropping the - // initial memory buffers. This prevents buffers from dropping before build_async completes. - for s in &self.merkle_tree.subtrees { - s.stream.synchronize().unwrap(); - } - self.merkle_tree.stream.synchronize().unwrap(); + // WARNING: The merkle subtree events must be completed before dropping the initial memory + // buffers. This prevents buffers from dropping before build_async completes. self.merkle_tree.drop_subtrees(); self.initial_memory.clear(); } diff --git a/crates/vm/src/system/cuda/merkle_tree/mod.rs b/crates/vm/src/system/cuda/merkle_tree/mod.rs index e835169a0d..ec0a91de76 100644 --- a/crates/vm/src/system/cuda/merkle_tree/mod.rs +++ b/crates/vm/src/system/cuda/merkle_tree/mod.rs @@ -9,7 +9,9 @@ use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBac use openvm_cuda_common::{ copy::{cuda_memcpy, MemCopyD2H, MemCopyH2D}, d_buffer::DeviceBuffer, - stream::{cudaStreamPerThread, default_stream_wait, CudaEvent, CudaStream}, + stream::{ + cudaStreamPerThread, current_stream_sync, default_stream_wait, CudaEvent, CudaStream, + }, }; use openvm_stark_backend::{ p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator}, @@ -39,7 +41,9 @@ pub const TIMESTAMPED_BLOCK_WIDTH: usize = 11; /// wait for the completion. pub struct MemoryMerkleSubTree { pub stream: Arc, - pub event: Option, + #[allow(dead_code)] // See `drop_subtrees` + created_buffer_event: Option, + build_completion_event: Option, pub buf: DeviceBuffer, pub height: usize, pub path_len: usize, @@ -81,7 +85,8 @@ impl MemoryMerkleSubTree { stream.wait(&created_buffer_event).unwrap(); Self { stream, - event: None, + created_buffer_event: Some(created_buffer_event), + build_completion_event: None, height, buf, path_len, @@ -91,7 +96,8 @@ impl MemoryMerkleSubTree { pub fn dummy() -> Self { Self { stream: Arc::new(CudaStream::new().unwrap()), - event: None, + created_buffer_event: None, + build_completion_event: None, height: 0, buf: DeviceBuffer::new(), path_len: 0, @@ -146,7 +152,7 @@ impl MemoryMerkleSubTree { event.record(self.stream.as_raw()).unwrap(); } } - self.event = Some(event); + self.build_completion_event = Some(event); } /// Returns the bounds [start, end) of the layer at the given depth. @@ -183,10 +189,9 @@ impl MemoryMerkleSubTree { /// /// Execution: /// - Subtrees are built asynchronously on individual CUDA streams. -/// - The final root is computed after all subtrees complete, on a shared stream. +/// - The final root is computed after all subtrees complete, on the default stream. /// - `CudaEvent`s are used to synchronize subtree completion. pub struct MemoryMerkleTree { - pub stream: Arc, pub subtrees: Vec, pub top_roots: DeviceBuffer, zero_hash: DeviceBuffer, @@ -234,7 +239,6 @@ impl MemoryMerkleTree { } Self { - stream: Arc::new(CudaStream::new().unwrap()), subtrees: Vec::new(), top_roots, height: label_max_bits + log2_ceil_usize(num_addr_spaces), @@ -273,18 +277,17 @@ impl MemoryMerkleTree { /// Finalizes the Merkle tree by collecting all subtree roots and computing the final root. /// Waits for all subtrees to complete and then performs the final hash operation. - pub fn finalize(&self) { + pub fn finalize(&mut self) { + // Default stream waits for all subtrees to complete for subtree in self.subtrees.iter() { - self.stream.wait(subtree.event.as_ref().unwrap()).unwrap(); - } - - let we_can_gather_bufs_event = CudaEvent::new().unwrap(); - unsafe { - we_can_gather_bufs_event - .record(self.stream.as_raw()) - .unwrap(); + default_stream_wait( + subtree + .build_completion_event + .as_ref() + .expect("Subtree build event does not exist"), + ) + .unwrap(); } - default_stream_wait(&we_can_gather_bufs_event).unwrap(); let roots: Vec = self .subtrees @@ -292,28 +295,37 @@ impl MemoryMerkleTree { .map(|subtree| subtree.buf.as_ptr() as usize) .collect(); let d_roots = roots.to_device().unwrap(); - let to_device_event = CudaEvent::new().unwrap(); - unsafe { - to_device_event.record(cudaStreamPerThread).unwrap(); - } - self.stream.wait(&to_device_event).unwrap(); unsafe { finalize_merkle_tree( &d_roots, &self.top_roots, self.subtrees.len(), - self.stream.as_raw(), + cudaStreamPerThread, ) .unwrap(); } - - self.stream.synchronize().unwrap(); } /// Drops all massive buffers to free memory. Used at the end of an execution segment. + /// + /// Caution: this method destroys all subtree streams and events. For safety, we force + /// synchronize all subtree streams and the default stream (cudaStreamPerThread) with host + /// before deallocating buffers. pub fn drop_subtrees(&mut self) { - self.subtrees = Vec::new(); + // Make sure all streams are synchronized before destroying events + for subtree in self.subtrees.iter() { + subtree.stream.synchronize().unwrap(); + if let Some(_event) = subtree.build_completion_event.as_ref() { + tracing::warn!( + "Dropping merkle subtree before build_async event has been destroyed" + ); + } + } + current_stream_sync().unwrap(); + // Clearing will drop streams (which calls synchronize again) and drops events (which + // destroys them) + self.subtrees.clear(); } /// Updates the tree and returns the merkle trace. @@ -324,6 +336,11 @@ impl MemoryMerkleTree { empty_touched_blocks: bool, ) -> AirProvingContext { let mut public_values = self.top_roots.to_host().unwrap()[0].to_vec(); + // .to_host() calls cudaEventSynchronize on the D2H memcpy, which also means all subtree + // events are now completed, so we can clean up the events. + for subtree in &mut self.subtrees { + subtree.build_completion_event = None; + } let merkle_trace = { let width = MemoryMerkleCols::::width(); let padded_height = next_power_of_two_or_zero(unpadded_height); @@ -397,6 +414,12 @@ impl MemoryMerkleTree { } } +impl Drop for MemoryMerkleTree { + fn drop(&mut self) { + self.drop_subtrees(); + } +} + #[cfg(test)] mod tests { use std::sync::Arc;