@@ -24,6 +24,7 @@ pub fn Pool(comptime Node: type) type {
         head: ?*Node = null,
 
         // Tracks chunks of allocated nodes, used for freeing them at deinit() time.
+        cleanup_mu: std.Thread.Mutex = .{},
         cleanup: std.ArrayListUnmanaged([*]Node) = .{},
 
         // How many nodes to allocate at once for each chunk in the pool.
@@ -72,10 +73,25 @@ pub fn Pool(comptime Node: type) type {
                 break; // Pool is empty
             }
 
-            // Pool is empty, allocate new chunk of nodes, and track the pointer for later cleanup
+            // Pool is empty, we need to allocate new nodes.
+            // This is the rare path, where we need a lock to ensure thread safety only for the
+            // pool.cleanup tracking list.
+            pool.cleanup_mu.lock();
+
+            // Check the pool again after acquiring the lock:
+            // another thread might have already allocated nodes while we were waiting.
+            const head2 = @atomicLoad(?*Node, &pool.head, .acquire);
+            if (head2) |_| {
+                // Pool is no longer empty, release the lock and try to acquire a node again.
+                pool.cleanup_mu.unlock();
+                return pool.acquire(allocator);
+            }
+
+            // Pool is still empty: allocate a new chunk of nodes and track the pointer for later cleanup.
             const new_nodes = try allocator.alloc(Node, pool.chunk_size);
             errdefer allocator.free(new_nodes);
             try pool.cleanup.append(allocator, @ptrCast(new_nodes.ptr));
+            pool.cleanup_mu.unlock();
 
             // Link all our new nodes (except the first one acquired by the caller) into a chain
             // with each other.
@@ -311,3 +327,43 @@ test "basic" {
     try std.testing.expectEqual(queue.pop(), 3);
     try std.testing.expectEqual(queue.pop(), null);
 }
+
+test "concurrent producers" {
+    const allocator = std.testing.allocator;
+
+    var queue: Queue(u32) = undefined;
+    try queue.init(allocator, 32);
+    defer queue.deinit(allocator);
+
+    const n_jobs = 100;
+    const n_entries: u32 = 10000;
+
+    var pool: std.Thread.Pool = undefined;
+    try std.Thread.Pool.init(&pool, .{ .allocator = allocator, .n_jobs = n_jobs });
+    defer pool.deinit();
+
+    var wg: std.Thread.WaitGroup = .{};
+    for (0..n_jobs) |_| {
+        pool.spawnWg(
+            &wg,
+            struct {
+                pub fn run(q: *Queue(u32)) void {
+                    var i: u32 = 0;
+                    while (i < n_entries) : (i += 1) {
+                        q.push(allocator, i) catch unreachable;
+                    }
+                }
+            }.run,
+            .{&queue},
+        );
+    }
+
+    wg.wait();
+
+    // Verify we can read some values without crashing.
+    var count: usize = 0;
+    while (queue.pop()) |_| {
+        count += 1;
+        if (count >= n_jobs * n_entries) break;
+    }
+}
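
For context on the cleanup list this change guards: its field comment says the tracked chunk pointers are freed at deinit() time, but deinit() itself is not part of this diff. Below is only a minimal sketch of what that teardown could look like, assuming each tracked [*]Node came from the allocator.alloc(Node, pool.chunk_size) call in acquire() and that no other threads touch the pool during teardown; the Self receiver type is a placeholder, not the repository's actual code.

    // Sketch only (not part of this commit): freeing the tracked chunks at
    // deinit() time. `Self` stands in for the struct type returned by Pool().
    pub fn deinit(pool: *Self, allocator: std.mem.Allocator) void {
        // Each entry in `cleanup` is the many-item pointer returned by
        // allocator.alloc(Node, pool.chunk_size) in acquire(); rebuild the
        // slice so free() knows the allocation's length.
        for (pool.cleanup.items) |chunk| {
            allocator.free(chunk[0..pool.chunk_size]);
        }
        pool.cleanup.deinit(allocator);
    }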