
Commit 8e48fbb

fix(cuda): ensure streams have observed events before destroying (#2242)
- MemoryMerkleTree was using a very strange pattern, passing its own stream to/from the default stream. I don't think this is necessary, since most of the work is done on subtree streams and the finalize kernel can run on the default stream to simplify things.
- Went through all places where an event was dropped (which destroys it) before the event was actually awaited on the stream. For subtree streams I just made them all synchronize, since those streams need to complete anyway. I did a small optimization to avoid another synchronize on the default stream (perhaps unnecessary): after a D2H transfer, I remove the events that must have been observed on all streams.

Comparison to show there's no perf regression: https://github.com/axiom-crypto/openvm-reth-benchmark/actions/runs/19352734146
1 parent 432dada commit 8e48fbb
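The hazard the commit message describes is dropping a CudaEvent (which destroys it) before the stream that is supposed to observe it has actually done so, or freeing device buffers while the recording stream's work is still in flight. A minimal sketch of the intended ordering, written against only the openvm_cuda_common::stream calls that appear in this diff (CudaStream::new, CudaEvent::new, record, synchronize, default_stream_wait); the function and the names worker/done are illustrative, not code from this repository:

// Illustrative sketch only: the record -> wait -> synchronize-before-destroy
// ordering this commit enforces, using the wrapper calls visible in the diff.
use std::sync::Arc;

use openvm_cuda_common::stream::{default_stream_wait, CudaEvent, CudaStream};

fn record_wait_then_destroy() {
    let worker = Arc::new(CudaStream::new().unwrap());
    // ... enqueue asynchronous work (e.g. a subtree build) on `worker` ...

    // 1. Record completion of the worker stream's queued work.
    let done = CudaEvent::new().unwrap();
    unsafe {
        done.record(worker.as_raw()).unwrap();
    }

    // 2. Order the default (per-thread) stream after the event *before* it
    //    consumes the results.
    default_stream_wait(&done).unwrap();

    // 3. Only drop the event (which destroys it) and free buffers the worker
    //    still references once the work is known to be complete.
    worker.synchronize().unwrap();
    drop(done);
}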

2 files changed: +55 −35 lines changed


crates/vm/src/system/cuda/memory.rs

Lines changed: 5 additions & 8 deletions
@@ -203,11 +203,12 @@ impl MemoryInventoryGPU {
                 }
                 mem.tracing_info("merkle update");
                 persistent.merkle_tree.finalize();
-                Some(persistent.merkle_tree.update_with_touched_blocks(
+                let merkle_tree_ctx = persistent.merkle_tree.update_with_touched_blocks(
                     unpadded_merkle_height,
                     &d_touched_memory,
                     empty,
-                ))
+                );
+                Some(merkle_tree_ctx)
             }
             TouchedMemory::Volatile(partition) => {
                 assert!(self.persistent.is_none(), "TouchedMemory enum mismatch");
@@ -234,12 +235,8 @@ impl MemoryInventoryGPU {
 
 impl Drop for PersistentMemoryInventoryGPU {
     fn drop(&mut self) {
-        // Force synchronize all streams in merkle tree before dropping the
-        // initial memory buffers. This prevents buffers from dropping before build_async completes.
-        for s in &self.merkle_tree.subtrees {
-            s.stream.synchronize().unwrap();
-        }
-        self.merkle_tree.stream.synchronize().unwrap();
+        // WARNING: The merkle subtree events must be completed before dropping the initial memory
+        // buffers. This prevents buffers from dropping before build_async completes.
         self.merkle_tree.drop_subtrees();
         self.initial_memory.clear();
     }
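Assembled from the hunk above, the Drop impl after this change reads as follows; the stream synchronization it used to perform itself now lives inside MemoryMerkleTree::drop_subtrees (shown in the second file):

// Post-change state, assembled from the diff hunk above (not new code).
impl Drop for PersistentMemoryInventoryGPU {
    fn drop(&mut self) {
        // WARNING: The merkle subtree events must be completed before dropping the initial memory
        // buffers. This prevents buffers from dropping before build_async completes.
        self.merkle_tree.drop_subtrees();
        self.initial_memory.clear();
    }
}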

crates/vm/src/system/cuda/merkle_tree/mod.rs

Lines changed: 50 additions & 27 deletions
@@ -9,7 +9,9 @@ use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBac
 use openvm_cuda_common::{
     copy::{cuda_memcpy, MemCopyD2H, MemCopyH2D},
     d_buffer::DeviceBuffer,
-    stream::{cudaStreamPerThread, default_stream_wait, CudaEvent, CudaStream},
+    stream::{
+        cudaStreamPerThread, current_stream_sync, default_stream_wait, CudaEvent, CudaStream,
+    },
 };
 use openvm_stark_backend::{
     p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator},
@@ -39,7 +41,9 @@ pub const TIMESTAMPED_BLOCK_WIDTH: usize = 11;
 /// wait for the completion.
 pub struct MemoryMerkleSubTree {
     pub stream: Arc<CudaStream>,
-    pub event: Option<CudaEvent>,
+    #[allow(dead_code)] // See `drop_subtrees`
+    created_buffer_event: Option<CudaEvent>,
+    build_completion_event: Option<CudaEvent>,
     pub buf: DeviceBuffer<H>,
     pub height: usize,
     pub path_len: usize,
@@ -81,7 +85,8 @@ impl MemoryMerkleSubTree {
         stream.wait(&created_buffer_event).unwrap();
         Self {
             stream,
-            event: None,
+            created_buffer_event: Some(created_buffer_event),
+            build_completion_event: None,
             height,
             buf,
             path_len,
@@ -91,7 +96,8 @@ impl MemoryMerkleSubTree {
     pub fn dummy() -> Self {
         Self {
             stream: Arc::new(CudaStream::new().unwrap()),
-            event: None,
+            created_buffer_event: None,
+            build_completion_event: None,
             height: 0,
             buf: DeviceBuffer::new(),
             path_len: 0,
@@ -146,7 +152,7 @@ impl MemoryMerkleSubTree {
                 event.record(self.stream.as_raw()).unwrap();
             }
         }
-        self.event = Some(event);
+        self.build_completion_event = Some(event);
     }
 
     /// Returns the bounds [start, end) of the layer at the given depth.
@@ -183,10 +189,9 @@ impl MemoryMerkleSubTree {
 ///
 /// Execution:
 /// - Subtrees are built asynchronously on individual CUDA streams.
-/// - The final root is computed after all subtrees complete, on a shared stream.
+/// - The final root is computed after all subtrees complete, on the default stream.
 /// - `CudaEvent`s are used to synchronize subtree completion.
 pub struct MemoryMerkleTree {
-    pub stream: Arc<CudaStream>,
     pub subtrees: Vec<MemoryMerkleSubTree>,
     pub top_roots: DeviceBuffer<H>,
     zero_hash: DeviceBuffer<H>,
@@ -234,7 +239,6 @@ impl MemoryMerkleTree {
         }
 
         Self {
-            stream: Arc::new(CudaStream::new().unwrap()),
             subtrees: Vec::new(),
             top_roots,
             height: label_max_bits + log2_ceil_usize(num_addr_spaces),
@@ -273,47 +277,55 @@ impl MemoryMerkleTree {
 
     /// Finalizes the Merkle tree by collecting all subtree roots and computing the final root.
     /// Waits for all subtrees to complete and then performs the final hash operation.
-    pub fn finalize(&self) {
+    pub fn finalize(&mut self) {
+        // Default stream waits for all subtrees to complete
         for subtree in self.subtrees.iter() {
-            self.stream.wait(subtree.event.as_ref().unwrap()).unwrap();
-        }
-
-        let we_can_gather_bufs_event = CudaEvent::new().unwrap();
-        unsafe {
-            we_can_gather_bufs_event
-                .record(self.stream.as_raw())
-                .unwrap();
+            default_stream_wait(
+                subtree
+                    .build_completion_event
+                    .as_ref()
+                    .expect("Subtree build event does not exist"),
+            )
+            .unwrap();
         }
-        default_stream_wait(&we_can_gather_bufs_event).unwrap();
 
         let roots: Vec<usize> = self
             .subtrees
             .iter()
             .map(|subtree| subtree.buf.as_ptr() as usize)
             .collect();
         let d_roots = roots.to_device().unwrap();
-        let to_device_event = CudaEvent::new().unwrap();
-        unsafe {
-            to_device_event.record(cudaStreamPerThread).unwrap();
-        }
-        self.stream.wait(&to_device_event).unwrap();
 
         unsafe {
             finalize_merkle_tree(
                 &d_roots,
                 &self.top_roots,
                 self.subtrees.len(),
-                self.stream.as_raw(),
+                cudaStreamPerThread,
             )
             .unwrap();
         }
-
-        self.stream.synchronize().unwrap();
     }
 
     /// Drops all massive buffers to free memory. Used at the end of an execution segment.
+    ///
+    /// Caution: this method destroys all subtree streams and events. For safety, we force
+    /// synchronize all subtree streams and the default stream (cudaStreamPerThread) with host
+    /// before deallocating buffers.
     pub fn drop_subtrees(&mut self) {
-        self.subtrees = Vec::new();
+        // Make sure all streams are synchronized before destroying events
+        for subtree in self.subtrees.iter() {
+            subtree.stream.synchronize().unwrap();
+            if let Some(_event) = subtree.build_completion_event.as_ref() {
+                tracing::warn!(
+                    "Dropping merkle subtree before build_async event has been destroyed"
+                );
+            }
+        }
+        current_stream_sync().unwrap();
+        // Clearing will drop streams (which calls synchronize again) and drops events (which
+        // destroys them)
+        self.subtrees.clear();
     }
 
     /// Updates the tree and returns the merkle trace.
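Because the deletions and additions interleave heavily in the hunk above, here is finalize as it reads after the change, assembled from the + and context lines: the dedicated tree stream is gone and everything is ordered against cudaStreamPerThread.

    // Post-change `finalize`, assembled from the diff hunk above (not new code).
    pub fn finalize(&mut self) {
        // Default stream waits for all subtrees to complete
        for subtree in self.subtrees.iter() {
            default_stream_wait(
                subtree
                    .build_completion_event
                    .as_ref()
                    .expect("Subtree build event does not exist"),
            )
            .unwrap();
        }

        let roots: Vec<usize> = self
            .subtrees
            .iter()
            .map(|subtree| subtree.buf.as_ptr() as usize)
            .collect();
        let d_roots = roots.to_device().unwrap();

        unsafe {
            finalize_merkle_tree(
                &d_roots,
                &self.top_roots,
                self.subtrees.len(),
                cudaStreamPerThread,
            )
            .unwrap();
        }
    }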
@@ -324,6 +336,11 @@ impl MemoryMerkleTree {
         empty_touched_blocks: bool,
     ) -> AirProvingContext<GpuBackend> {
         let mut public_values = self.top_roots.to_host().unwrap()[0].to_vec();
+        // .to_host() calls cudaEventSynchronize on the D2H memcpy, which also means all subtree
+        // events are now completed, so we can clean up the events.
+        for subtree in &mut self.subtrees {
+            subtree.build_completion_event = None;
+        }
         let merkle_trace = {
             let width = MemoryMerkleCols::<u8, DIGEST_WIDTH>::width();
             let padded_height = next_power_of_two_or_zero(unpadded_height);
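The comment added in this hunk is the ordering argument behind the commit's small optimization: a host-blocking D2H copy on the default stream, issued after that stream has waited on every subtree event, implies those events have completed, so they can be destroyed without a separate synchronize. A minimal sketch of that implication chain, using the MemCopyD2H/MemCopyH2D and stream helpers imported in this file; the names worker and d_scratch are illustrative, and whether to_device/to_host are implemented for this exact element type is an assumption:

// Illustrative sketch of the "D2H copy implies events observed" argument.
use std::sync::Arc;

use openvm_cuda_common::{
    copy::{MemCopyD2H, MemCopyH2D},
    stream::{default_stream_wait, CudaEvent, CudaStream},
};

fn d2h_copy_implies_events_observed() {
    let worker = Arc::new(CudaStream::new().unwrap());
    // ... enqueue asynchronous work on `worker` ...
    let done = CudaEvent::new().unwrap();
    unsafe {
        done.record(worker.as_raw()).unwrap();
    }

    // The default stream is ordered after `done`.
    default_stream_wait(&done).unwrap();

    // A host-blocking D2H copy on the default stream returns only after the
    // work the stream was ordered behind (including `done`) has completed.
    let d_scratch = vec![0usize; 4].to_device().unwrap();
    let _h_scratch = d_scratch.to_host().unwrap();

    // Safe to destroy the event now: no stream still has an unobserved wait on it.
    drop(done);
}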
@@ -397,6 +414,12 @@ impl MemoryMerkleTree {
     }
 }
 
+impl Drop for MemoryMerkleTree {
+    fn drop(&mut self) {
+        self.drop_subtrees();
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
