@@ -41,7 +41,9 @@ pub const TIMESTAMPED_BLOCK_WIDTH: usize = 11;
4141/// wait for the completion.
4242pub struct MemoryMerkleSubTree {
4343 pub stream : Arc < CudaStream > ,
44- pub event : Option < CudaEvent > ,
44+ #[ allow( dead_code) ] // See `drop_subtrees`
45+ created_buffer_event : Option < CudaEvent > ,
46+ build_completion_event : Option < CudaEvent > ,
4547 pub buf : DeviceBuffer < H > ,
4648 pub height : usize ,
4749 pub path_len : usize ,
@@ -83,7 +85,8 @@ impl MemoryMerkleSubTree {
8385 stream. wait ( & created_buffer_event) . unwrap ( ) ;
8486 Self {
8587 stream,
86- event : None ,
88+ created_buffer_event : Some ( created_buffer_event) ,
89+ build_completion_event : None ,
8790 height,
8891 buf,
8992 path_len,
@@ -93,7 +96,8 @@ impl MemoryMerkleSubTree {
9396 pub fn dummy ( ) -> Self {
9497 Self {
9598 stream : Arc :: new ( CudaStream :: new ( ) . unwrap ( ) ) ,
96- event : None ,
99+ created_buffer_event : None ,
100+ build_completion_event : None ,
97101 height : 0 ,
98102 buf : DeviceBuffer :: new ( ) ,
99103 path_len : 0 ,
@@ -148,7 +152,7 @@ impl MemoryMerkleSubTree {
148152 event. record ( self . stream . as_raw ( ) ) . unwrap ( ) ;
149153 }
150154 }
151- self . event = Some ( event) ;
155+ self . build_completion_event = Some ( event) ;
152156 }
153157
154158 /// Returns the bounds [start, end) of the layer at the given depth.
@@ -280,7 +284,7 @@ impl MemoryMerkleTree {
280284 for subtree in self . subtrees . iter ( ) {
281285 default_stream_wait (
282286 subtree
283- . event
287+ . build_completion_event
284288 . as_ref ( )
285289 . expect ( "Subtree build event does not exist" ) ,
286290 )
@@ -314,15 +318,19 @@ impl MemoryMerkleTree {
314318 /// synchronization like D2H transfer).
315319 pub fn drop_subtrees ( & mut self ) {
316320 let mut needs_sync = false ;
321+ // Make sure all streams are synchronized before destroying events
317322 for subtree in self . subtrees . iter ( ) {
318- if let Some ( event) = subtree. event . as_ref ( ) {
323+ subtree. stream . synchronize ( ) . unwrap ( ) ;
324+ if let Some ( event) = subtree. build_completion_event . as_ref ( ) {
319325 needs_sync = true ;
320326 default_stream_wait ( event) . unwrap ( ) ;
321327 }
322328 }
323329 if needs_sync {
324330 current_stream_sync ( ) . unwrap ( ) ;
325331 }
332+ // Clearing will drop streams (which calls synchronize again) and drops events (which
333+ // destroys them)
326334 self . subtrees . clear ( ) ;
327335 }
328336
@@ -335,9 +343,9 @@ impl MemoryMerkleTree {
335343 ) -> AirProvingContext < GpuBackend > {
336344 let mut public_values = self . top_roots . to_host ( ) . unwrap ( ) [ 0 ] . to_vec ( ) ;
337345 // .to_host() calls cudaEventSynchronize on the D2H memcpy, which also means all subtree
338- // events are now completed, so we can clear up the events.
346+ // events are now completed, so we can clean up the events.
339347 for subtree in & mut self . subtrees {
340- subtree. event = None ;
348+ subtree. build_completion_event = None ;
341349 }
342350 let merkle_trace = {
343351 let width = MemoryMerkleCols :: < u8 , DIGEST_WIDTH > :: width ( ) ;
0 commit comments