Set runspec on all new tasks to avoid deadlocks (#4432)
* Set runspec in add_task
* Don't transition task to waiting prior to releasing
* Also, check that we don't release keys with dependents
* Transition instead of release
* Clear who_has on released task
* Allow constrained error transition
* Remove unnecessary definition for transition_ready_waiting
fjetter authored Jan 18, 2021
1 parent b21eac3 commit 9442d9b
Showing 1 changed file with 67 additions and 12 deletions.
distributed/worker.py (79 changes: 67 additions & 12 deletions)
@@ -451,6 +451,10 @@ def __init__(
("waiting", "flight"): self.transition_waiting_flight,
("ready", "executing"): self.transition_ready_executing,
("ready", "memory"): self.transition_ready_memory,
("ready", "error"): self.transition_ready_error,
("ready", "waiting"): self.transition_ready_waiting,
("constrained", "waiting"): self.transition_ready_waiting,
("constrained", "error"): self.transition_ready_error,
("constrained", "executing"): self.transition_constrained_executing,
("executing", "memory"): self.transition_executing_done,
("executing", "error"): self.transition_executing_done,
@@ -1446,8 +1450,10 @@ def add_task(
**kwargs2,
):
try:
runspec = SerializedTask(function, args, kwargs, task)
if key in self.tasks:
ts = self.tasks[key]
ts.runspec = runspec
if ts.state == "memory":
assert key in self.data or key in self.actors
logger.debug(
@@ -1476,6 +1482,7 @@ def add_task(
if actor:
self.actors[ts.key] = None

ts.runspec = runspec
ts.priority = priority
ts.duration = duration
if resource_restrictions:
@@ -1561,13 +1568,14 @@ def transition_waiting_flight(self, ts, worker=None):
pdb.set_trace()
raise

def transition_flight_waiting(self, ts, worker=None, remove=True):
def transition_flight_waiting(self, ts, worker=None, remove=True, runspec=None):
try:
if self.validate:
assert ts.state == "flight"

self.in_flight_tasks -= 1
ts.coming_from = None
ts.runspec = runspec or ts.runspec
if remove:
try:
ts.who_has.remove(worker)
@@ -1692,9 +1700,23 @@ def transition_ready_executing(self, ts):
pdb.set_trace()
raise

def transition_ready_error(self, ts):
if self.validate:
assert ts.exception is not None
assert ts.traceback is not None
self.send_task_state_to_scheduler(ts)

def transition_ready_memory(self, ts, value=None):
if value:
self.put_key_in_memory(ts, value=value)
self.send_task_state_to_scheduler(ts)

def transition_ready_waiting(self, ts):
"""
This transition is common for work stealing
"""
pass

def transition_constrained_executing(self, ts):
self.transition_ready_executing(ts)
for resource, quantity in ts.resource_restrictions.items():
@@ -2145,7 +2167,6 @@ async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None):
self.repetitively_busy += 1
await asyncio.sleep(0.100 * 1.5 ** self.repetitively_busy)

# See if anyone new has the data
await self.query_who_has(dep.key)
self.ensure_communicating()

@@ -2227,10 +2248,11 @@ def update_who_has(self, who_has):
if not workers:
continue

self.tasks[dep].who_has.update(workers)
if dep in self.tasks:
self.tasks[dep].who_has.update(workers)

for worker in workers:
self.has_what[worker].add(dep)
for worker in workers:
self.has_what[worker].add(dep)
except Exception as e:
logger.exception(e)
if LOG_PDB:
@@ -2243,20 +2265,34 @@ def steal_request(self, key):
# There may be a race condition between stealing and releasing a task.
# In this case the self.tasks is already cleared. The `None` will be
# registered as `already-computing` on the other end
ts = self.tasks.get(key)
if key in self.tasks:
state = self.tasks[key].state
state = ts.state
else:
state = None

response = {"op": "steal-response", "key": key, "state": state}
self.batched_stream.send(response)

if state in ("ready", "waiting", "constrained"):
self.release_key(key)
# The runspec should be reset by the transition. However, the
# waiting->waiting transition results in a no-op which would not reset it.
ts.runspec = None
self.transition(ts, "waiting")
if not ts.dependents:
self.release_key(ts.key)
if self.validate:
assert ts.key not in self.tasks
if self.validate:
assert ts.runspec is None

def release_key(self, key, cause=None, reason=None, report=True):
try:
if self.validate:
assert isinstance(key, str)
ts = self.tasks.get(key, TaskState(key=key))

if cause:
self.log.append((key, "release-key", {"cause": cause}))
else:
@@ -2280,6 +2316,7 @@ def release_key(self, key, cause=None, reason=None, report=True):

for worker in ts.who_has:
self.has_what[worker].discard(ts.key)
ts.who_has.clear()

if key in self.threads:
del self.threads[key]
@@ -2296,7 +2333,7 @@ def release_key(self, key, cause=None, reason=None, report=True):
self.batched_stream.send({"op": "release", "key": key, "cause": cause})

self._notify_plugins("release_key", key, ts.state, cause, reason, report)
if key in self.tasks:
if key in self.tasks and not ts.dependents:
self.tasks.pop(key)
del ts
except CommClosedError:
@@ -2532,8 +2569,15 @@ async def execute(self, key, report=False):
if key not in self.tasks:
return
ts = self.tasks[key]
if ts.state != "executing" or ts.runspec is None:
if ts.state != "executing":
# This might happen if keys are canceled
logger.debug(
"Trying to execute a task %s which is not in executing state anymore"
% ts
)
return
if ts.runspec is None:
logger.critical("No runspec available for task %s." % ts)
if self.validate:
assert not ts.waiting_for_data
assert ts.state == "executing"
@@ -2583,7 +2627,17 @@ async def execute(self, key, report=False):
executor_error = e
raise

if ts.state not in ("executing", "long-running"):
# We'll need to check again for the task state since it may have
# changed since the execution was kicked off. In particular, it may
# have been canceled and released already in which case we'll have
# to drop the result immediately
key = ts.key
ts = self.tasks.get(key)

if ts is None:
logger.debug(
"Dropping result for %s since task has already been released." % key
)
return

result["key"] = ts.key
@@ -2876,13 +2930,14 @@ def _notify_plugins(self, method_name, *args, **kwargs):

def validate_task_memory(self, ts):
assert ts.key in self.data or ts.key in self.actors
assert ts.nbytes is not None
assert isinstance(ts.nbytes, int)
assert not ts.waiting_for_data
assert ts.key not in self.ready
assert ts.state == "memory"

def validate_task_executing(self, ts):
assert ts.state == "executing"
assert ts.runspec is not None
assert ts.key not in self.data
assert not ts.waiting_for_data
assert all(
@@ -2901,7 +2956,7 @@ def validate_task_ready(self, ts):
def validate_task_waiting(self, ts):
assert ts.key not in self.data
assert ts.state == "waiting"
if ts.dependencies:
if ts.dependencies and ts.runspec:
assert not all(dep.key in self.data for dep in ts.dependencies)

def validate_task_flight(self, ts):
