@@ -9,7 +9,9 @@ use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBac
99use openvm_cuda_common:: {
1010 copy:: { cuda_memcpy, MemCopyD2H , MemCopyH2D } ,
1111 d_buffer:: DeviceBuffer ,
12- stream:: { cudaStreamPerThread, default_stream_wait, CudaEvent , CudaStream } ,
12+ stream:: {
13+ cudaStreamPerThread, current_stream_sync, default_stream_wait, CudaEvent , CudaStream ,
14+ } ,
1315} ;
1416use openvm_stark_backend:: {
1517 p3_maybe_rayon:: prelude:: { IntoParallelIterator , ParallelIterator } ,
@@ -39,7 +41,9 @@ pub const TIMESTAMPED_BLOCK_WIDTH: usize = 11;
3941/// wait for the completion.
4042pub struct MemoryMerkleSubTree {
4143 pub stream : Arc < CudaStream > ,
42- pub event : Option < CudaEvent > ,
44+ #[ allow( dead_code) ] // See `drop_subtrees`
45+ created_buffer_event : Option < CudaEvent > ,
46+ build_completion_event : Option < CudaEvent > ,
4347 pub buf : DeviceBuffer < H > ,
4448 pub height : usize ,
4549 pub path_len : usize ,
@@ -81,7 +85,8 @@ impl MemoryMerkleSubTree {
8185 stream. wait ( & created_buffer_event) . unwrap ( ) ;
8286 Self {
8387 stream,
84- event : None ,
88+ created_buffer_event : Some ( created_buffer_event) ,
89+ build_completion_event : None ,
8590 height,
8691 buf,
8792 path_len,
@@ -91,7 +96,8 @@ impl MemoryMerkleSubTree {
9196 pub fn dummy ( ) -> Self {
9297 Self {
9398 stream : Arc :: new ( CudaStream :: new ( ) . unwrap ( ) ) ,
94- event : None ,
99+ created_buffer_event : None ,
100+ build_completion_event : None ,
95101 height : 0 ,
96102 buf : DeviceBuffer :: new ( ) ,
97103 path_len : 0 ,
@@ -146,7 +152,7 @@ impl MemoryMerkleSubTree {
146152 event. record ( self . stream . as_raw ( ) ) . unwrap ( ) ;
147153 }
148154 }
149- self . event = Some ( event) ;
155+ self . build_completion_event = Some ( event) ;
150156 }
151157
152158 /// Returns the bounds [start, end) of the layer at the given depth.
@@ -183,10 +189,9 @@ impl MemoryMerkleSubTree {
183189///
184190/// Execution:
185191/// - Subtrees are built asynchronously on individual CUDA streams.
186- /// - The final root is computed after all subtrees complete, on a shared stream.
192+ /// - The final root is computed after all subtrees complete, on the default stream.
187193/// - `CudaEvent`s are used to synchronize subtree completion.
188194pub struct MemoryMerkleTree {
189- pub stream : Arc < CudaStream > ,
190195 pub subtrees : Vec < MemoryMerkleSubTree > ,
191196 pub top_roots : DeviceBuffer < H > ,
192197 zero_hash : DeviceBuffer < H > ,
@@ -234,7 +239,6 @@ impl MemoryMerkleTree {
234239 }
235240
236241 Self {
237- stream : Arc :: new ( CudaStream :: new ( ) . unwrap ( ) ) ,
238242 subtrees : Vec :: new ( ) ,
239243 top_roots,
240244 height : label_max_bits + log2_ceil_usize ( num_addr_spaces) ,
@@ -273,47 +277,55 @@ impl MemoryMerkleTree {
273277
274278 /// Finalizes the Merkle tree by collecting all subtree roots and computing the final root.
275279 /// Waits for all subtrees to complete and then performs the final hash operation.
276- pub fn finalize ( & self ) {
280+ pub fn finalize ( & mut self ) {
281+ // Default stream waits for all subtrees to complete
277282 for subtree in self . subtrees . iter ( ) {
278- self . stream . wait ( subtree. event . as_ref ( ) . unwrap ( ) ) . unwrap ( ) ;
279- }
280-
281- let we_can_gather_bufs_event = CudaEvent :: new ( ) . unwrap ( ) ;
282- unsafe {
283- we_can_gather_bufs_event
284- . record ( self . stream . as_raw ( ) )
285- . unwrap ( ) ;
283+ default_stream_wait (
284+ subtree
285+ . build_completion_event
286+ . as_ref ( )
287+ . expect ( "Subtree build event does not exist" ) ,
288+ )
289+ . unwrap ( ) ;
286290 }
287- default_stream_wait ( & we_can_gather_bufs_event) . unwrap ( ) ;
288291
289292 let roots: Vec < usize > = self
290293 . subtrees
291294 . iter ( )
292295 . map ( |subtree| subtree. buf . as_ptr ( ) as usize )
293296 . collect ( ) ;
294297 let d_roots = roots. to_device ( ) . unwrap ( ) ;
295- let to_device_event = CudaEvent :: new ( ) . unwrap ( ) ;
296- unsafe {
297- to_device_event. record ( cudaStreamPerThread) . unwrap ( ) ;
298- }
299- self . stream . wait ( & to_device_event) . unwrap ( ) ;
300298
301299 unsafe {
302300 finalize_merkle_tree (
303301 & d_roots,
304302 & self . top_roots ,
305303 self . subtrees . len ( ) ,
306- self . stream . as_raw ( ) ,
304+ cudaStreamPerThread ,
307305 )
308306 . unwrap ( ) ;
309307 }
310-
311- self . stream . synchronize ( ) . unwrap ( ) ;
312308 }
313309
314310 /// Drops all massive buffers to free memory. Used at the end of an execution segment.
311+ ///
312+ /// Caution: this method destroys all subtree streams and events. For safety, we force
313+ /// synchronize all subtree streams and the default stream (cudaStreamPerThread) with host
314+ /// before deallocating buffers.
315315 pub fn drop_subtrees ( & mut self ) {
316- self . subtrees = Vec :: new ( ) ;
316+ // Make sure all streams are synchronized before destroying events
317+ for subtree in self . subtrees . iter ( ) {
318+ subtree. stream . synchronize ( ) . unwrap ( ) ;
319+ if let Some ( _event) = subtree. build_completion_event . as_ref ( ) {
320+ tracing:: warn!(
321+ "Dropping merkle subtree before build_async event has been destroyed"
322+ ) ;
323+ }
324+ }
325+ current_stream_sync ( ) . unwrap ( ) ;
326+ // Clearing will drop streams (which calls synchronize again) and drops events (which
327+ // destroys them)
328+ self . subtrees . clear ( ) ;
317329 }
318330
319331 /// Updates the tree and returns the merkle trace.
@@ -324,6 +336,11 @@ impl MemoryMerkleTree {
324336 empty_touched_blocks : bool ,
325337 ) -> AirProvingContext < GpuBackend > {
326338 let mut public_values = self . top_roots . to_host ( ) . unwrap ( ) [ 0 ] . to_vec ( ) ;
339+ // .to_host() calls cudaEventSynchronize on the D2H memcpy, which also means all subtree
340+ // events are now completed, so we can clean up the events.
341+ for subtree in & mut self . subtrees {
342+ subtree. build_completion_event = None ;
343+ }
327344 let merkle_trace = {
328345 let width = MemoryMerkleCols :: < u8 , DIGEST_WIDTH > :: width ( ) ;
329346 let padded_height = next_power_of_two_or_zero ( unpadded_height) ;
@@ -397,6 +414,12 @@ impl MemoryMerkleTree {
397414 }
398415}
399416
417+ impl Drop for MemoryMerkleTree {
418+ fn drop ( & mut self ) {
419+ self . drop_subtrees ( ) ;
420+ }
421+ }
422+
400423#[ cfg( test) ]
401424mod tests {
402425 use std:: sync:: Arc ;
0 commit comments