Skip to content

Commit

Permalink
fix mypy errors with improved type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Oct 25, 2024
1 parent c6f1224 commit c550aff
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 23 deletions.
8 changes: 6 additions & 2 deletions daisy/serial_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .block import BlockStatus
from .scheduler import Scheduler
from .task import Task
from .task_state import TaskState
from .server_observer import ServerObservee
import logging
Expand All @@ -11,14 +12,16 @@ class SerialServer(ServerObservee):
def __init__(self):
super().__init__()

def run_blockwise(self, tasks, scheduler=None) -> dict[str, TaskState]:
def run_blockwise(
self, tasks: list[Task], scheduler=None
) -> dict[str, TaskState]:
if scheduler is None:
scheduler = Scheduler(tasks)
else:
scheduler = scheduler

started_tasks = set()
finished_tasks = set()
finished_tasks: set[str] = set()
all_tasks = set(task.task_id for task in tasks)
process_funcs = {task.task_id: task.process_function for task in tasks}

Expand Down Expand Up @@ -63,3 +66,4 @@ def run_blockwise(self, tasks, scheduler=None) -> dict[str, TaskState]:
if len(process_funcs) == 0:
self.notify_server_exit()
return scheduler.task_states
raise NotImplementedError("Unreachable")
26 changes: 5 additions & 21 deletions daisy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@


class Server(ServerObservee):

def __init__(self, stop_event=None):

super().__init__()

if stop_event is None:
Expand All @@ -41,19 +39,20 @@ def __init__(self, stop_event=None):
logger.debug("Started server listening at %s:%s", self.hostname, self.port)

def run_blockwise(self, tasks, scheduler=None) -> dict[str, TaskState]:

if scheduler is None:
self.scheduler = Scheduler(tasks)
else:
self.scheduler = scheduler

self.worker_pools = TaskWorkerPools(tasks, self)
self.block_bookkeeper = BlockBookkeeper()
self.started_tasks = set()
self.finished_tasks = set()
self.started_tasks: set[str] = set()
self.finished_tasks: set[str] = set()
self.all_done = False

self.pending_requests = {task.task_id: Queue() for task in tasks}
self.pending_requests: dict[str, Queue] = {
task.task_id: Queue() for task in tasks
}

self._recruit_workers()

Expand Down Expand Up @@ -82,7 +81,6 @@ def _event_loop(self):
last_time = current_time

def _get_client_message(self):

try:
message = self.tcp_server.get_message(timeout=0.1)
except StreamClosedError:
Expand All @@ -92,22 +90,19 @@ def _get_client_message(self):
return message

for task_id, requests in self.pending_requests.items():

if self.pending_requests[task_id].empty():
continue

logger.debug("Answering delayed request for task %s", task_id)
return self.pending_requests[task_id].get()

def _send_client_message(self, stream, message):

try:
stream.send_message(message)
except StreamClosedError:
pass

def _handle_client_messages(self):

message = self._get_client_message()

if message is None:
Expand All @@ -116,7 +111,6 @@ def _handle_client_messages(self):
self._handle_client_message(message)

def _handle_client_message(self, message):

if isinstance(message, AcquireBlock):
self._handle_acquire_block(message)
elif isinstance(message, ReleaseBlock):
Expand All @@ -129,7 +123,6 @@ def _handle_client_message(self, message):
self._check_all_tasks_completed()

def _handle_acquire_block(self, message):

logger.debug("Received block request for task %s", message.task_id)

task_state = self.scheduler.task_states[message.task_id]
Expand All @@ -139,11 +132,9 @@ def _handle_acquire_block(self, message):
block = self.scheduler.acquire_block(message.task_id)

if block is None:

assert task_state.ready_count == 0

if task_state.pending_count == 0:

logger.debug(
"No more pending blocks for task %s, terminating " "client",
message.task_id,
Expand All @@ -168,7 +159,6 @@ def _handle_acquire_block(self, message):
self.pending_requests[message.task_id].put(message)

else:

try:
logger.debug("Sending block %s to client", block)
self._send_client_message(message.stream, SendBlock(block))
Expand All @@ -178,7 +168,6 @@ def _handle_acquire_block(self, message):
self.notify_acquire_block(message.task_id, task_state)

def _handle_release_block(self, message):

logger.debug("Client releases block %s", message.block)
self._safe_release_block(message.block, message.stream)

Expand Down Expand Up @@ -232,9 +221,7 @@ def _safe_release_block(self, block, stream):
)

def _handle_client_exception(self, message):

if isinstance(message, BlockFailed):

logger.error(
"Block %s failed in worker %s with %s",
message.block,
Expand All @@ -252,7 +239,6 @@ def _handle_client_exception(self, message):
raise message.exception

def _recruit_workers(self):

ready_tasks = self.scheduler.get_ready_tasks()
ready_tasks = {task.task_id: task for task in ready_tasks}

Expand All @@ -273,12 +259,10 @@ def _recruit_workers(self):
self.worker_pools.recruit_workers(ready_tasks)

def _check_for_lost_blocks(self):

lost_blocks = self.block_bookkeeper.get_lost_blocks()

# mark as failed and release the lost blocks
for block in lost_blocks:

logger.error("Block %s was lost, returning it to scheduler", block)
block.status = BlockStatus.FAILED
self._release_block(block)

0 comments on commit c550aff

Please sign in to comment.