Skip to content

Commit 54e1e97

Browse files
vtjnashKristofferC
authored andcommitted
fix waitall deadlock if any errors occur (#60030)
When errors occur, `waitall` may skip allocating Channel producers, leading to deadlock in the subsequent loop in the event that the user asked it to failfast (ironically). This is seen often in the failing of the threads_exec test ever since the test was added for this call. Simplify this to just use separate loops for the wait and the return computation. (cherry picked from commit e2f3178)
1 parent 54749bc commit 54e1e97

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

base/task.jl

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,11 @@ completed tasks, and the other consists of uncompleted tasks.
383383
each runs serially, since this needs to scan the list of `tasks` each time and
384384
synchronize with each one every time this is called. Or consider using
385385
[`waitall(tasks; failfast=true)`](@ref waitall) instead.
386+
387+
!!! compat "Julia 1.12"
388+
This function requires at least Julia 1.12.
386389
"""
387-
waitany(tasks; throw=true) = _wait_multiple(tasks, throw)
390+
waitany(tasks; throw=true) = _wait_multiple(collect_tasks(tasks), throw)
388391

389392
"""
390393
waitall(tasks; failfast=true, throw=true) -> (done_tasks, remaining_tasks)
@@ -400,17 +403,22 @@ given tasks is finished by exception. If `throw` is `true`, throw
400403
401404
The return value consists of two task vectors. The first one consists of
402405
completed tasks, and the other consists of uncompleted tasks.
406+
407+
!!! compat "Julia 1.12"
408+
This function requires at least Julia 1.12.
403409
"""
404-
waitall(tasks; failfast=true, throw=true) = _wait_multiple(tasks, throw, true, failfast)
410+
waitall(tasks; failfast=true, throw=true) = _wait_multiple(collect_tasks(tasks), throw, true, failfast)
405411

406-
function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false)
412+
function collect_tasks(waiting_tasks)
407413
tasks = Task[]
408-
409414
for t in waiting_tasks
410415
t isa Task || error("Expected an iterator of `Task` object")
411416
push!(tasks, t)
412417
end
418+
return tasks
419+
end
413420

421+
function _wait_multiple(tasks::Vector{Task}, throwexc::Bool=false, all::Bool=false, failfast::Bool=false)
414422
if (all && !failfast) || length(tasks) <= 1
415423
exception = false
416424
# Force everything to finish synchronously for the case of waitall
@@ -474,22 +482,36 @@ function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false
474482
end
475483

476484
while nremaining > 0
485+
exception && failfast && break
477486
i = take!(chan)
478487
t = tasks[i]
479488
waiter_tasks[i] = sentinel
480489
done_mask[i] = true
481490
exception |= istaskfailed(t)
482491
nremaining -= 1
483-
484-
# stop early if requested, unless there is something immediately
485-
# ready to consume from the channel (using a race-y check)
486-
if (!all || (failfast && exception)) && !isready(chan)
487-
break
488-
end
492+
# stop early if requested
493+
all || break
489494
end
490495

491496
close(chan)
492497

498+
# now just read which tasks finished directly: the channel is not needed anymore for that
499+
# repeat until we get (acquire) the list of all dependent-exited tasks
500+
changed = true
501+
while changed
502+
changed = false
503+
for (i, done) in enumerate(done_mask)
504+
done && continue
505+
t = tasks[i]
506+
if istaskdone(t)
507+
done_mask[i] = true
508+
exception |= istaskfailed(t)
509+
nremaining -= 1
510+
changed = true
511+
end
512+
end
513+
end
514+
493515
if nremaining == 0
494516
if throwexc && exception
495517
exceptions = [TaskFailedException(t) for t in tasks if istaskfailed(t)]
@@ -500,6 +522,7 @@ function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false
500522
remaining_mask = .~done_mask
501523
for i in findall(remaining_mask)
502524
waiter = waiter_tasks[i]
525+
waiter === sentinel && continue
503526
donenotify = tasks[i].donenotify::ThreadSynchronizer
504527
@lock donenotify list_deletefirst!(donenotify.waitq, waiter)
505528
end

test/threads_exec.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,9 +1372,7 @@ end
13721372
tasks = [Threads.@spawn(div(1, i)) for i = 0:1]
13731373
wait(tasks[1]; throw=false)
13741374
wait(tasks[2]; throw=false)
1375-
@test_throws CompositeException begin
1376-
waitall(Threads.@spawn(div(1, i)) for i = 0:1)
1377-
end
1375+
@test_throws CompositeException waitall(tasks)
13781376
end
13791377
end
13801378
end

0 commit comments

Comments
 (0)