Skip to content

Commit 5919a72

Browse files
committed
fix: do not destroy created_buffer_event until sync
1 parent 0c8f242 commit 5919a72

File tree

1 file changed

+16
-8
lines changed
  • crates/vm/src/system/cuda/merkle_tree

1 file changed

+16
-8
lines changed

crates/vm/src/system/cuda/merkle_tree/mod.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ pub const TIMESTAMPED_BLOCK_WIDTH: usize = 11;
4141
/// wait for the completion.
4242
pub struct MemoryMerkleSubTree {
4343
pub stream: Arc<CudaStream>,
44-
pub event: Option<CudaEvent>,
44+
#[allow(dead_code)] // See `drop_subtrees`
45+
created_buffer_event: Option<CudaEvent>,
46+
build_completion_event: Option<CudaEvent>,
4547
pub buf: DeviceBuffer<H>,
4648
pub height: usize,
4749
pub path_len: usize,
@@ -83,7 +85,8 @@ impl MemoryMerkleSubTree {
8385
stream.wait(&created_buffer_event).unwrap();
8486
Self {
8587
stream,
86-
event: None,
88+
created_buffer_event: Some(created_buffer_event),
89+
build_completion_event: None,
8790
height,
8891
buf,
8992
path_len,
@@ -93,7 +96,8 @@ impl MemoryMerkleSubTree {
9396
pub fn dummy() -> Self {
9497
Self {
9598
stream: Arc::new(CudaStream::new().unwrap()),
96-
event: None,
99+
created_buffer_event: None,
100+
build_completion_event: None,
97101
height: 0,
98102
buf: DeviceBuffer::new(),
99103
path_len: 0,
@@ -148,7 +152,7 @@ impl MemoryMerkleSubTree {
148152
event.record(self.stream.as_raw()).unwrap();
149153
}
150154
}
151-
self.event = Some(event);
155+
self.build_completion_event = Some(event);
152156
}
153157

154158
/// Returns the bounds [start, end) of the layer at the given depth.
@@ -280,7 +284,7 @@ impl MemoryMerkleTree {
280284
for subtree in self.subtrees.iter() {
281285
default_stream_wait(
282286
subtree
283-
.event
287+
.build_completion_event
284288
.as_ref()
285289
.expect("Subtree build event does not exist"),
286290
)
@@ -314,15 +318,19 @@ impl MemoryMerkleTree {
314318
/// synchronization like D2H transfer).
315319
pub fn drop_subtrees(&mut self) {
316320
let mut needs_sync = false;
321+
// Make sure all streams are synchronized before destroying events
317322
for subtree in self.subtrees.iter() {
318-
if let Some(event) = subtree.event.as_ref() {
323+
subtree.stream.synchronize().unwrap();
324+
if let Some(event) = subtree.build_completion_event.as_ref() {
319325
needs_sync = true;
320326
default_stream_wait(event).unwrap();
321327
}
322328
}
323329
if needs_sync {
324330
current_stream_sync().unwrap();
325331
}
332+
// Clearing drops the streams (which synchronizes them again) and drops the events (which
333+
// destroys them)
326334
self.subtrees.clear();
327335
}
328336

@@ -335,9 +343,9 @@ impl MemoryMerkleTree {
335343
) -> AirProvingContext<GpuBackend> {
336344
let mut public_values = self.top_roots.to_host().unwrap()[0].to_vec();
337345
// .to_host() calls cudaEventSynchronize on the D2H memcpy, which also means all subtree
338-
// events are now completed, so we can clear up the events.
346+
// events are now completed, so we can clean up the events.
339347
for subtree in &mut self.subtrees {
340-
subtree.event = None;
348+
subtree.build_completion_event = None;
341349
}
342350
let merkle_trace = {
343351
let width = MemoryMerkleCols::<u8, DIGEST_WIDTH>::width();

0 commit comments

Comments
 (0)