diff --git a/daisy/serial_server.py b/daisy/serial_server.py index dc586ac0..b44e01c8 100644 --- a/daisy/serial_server.py +++ b/daisy/serial_server.py @@ -1,5 +1,6 @@ from .block import BlockStatus from .scheduler import Scheduler +from .task import Task from .task_state import TaskState from .server_observer import ServerObservee import logging @@ -11,14 +12,16 @@ class SerialServer(ServerObservee): def __init__(self): super().__init__() - def run_blockwise(self, tasks, scheduler=None) -> dict[str, TaskState]: + def run_blockwise( + self, tasks: list[Task], scheduler=None + ) -> dict[str, TaskState]: if scheduler is None: scheduler = Scheduler(tasks) else: scheduler = scheduler started_tasks = set() - finished_tasks = set() + finished_tasks: set[str] = set() all_tasks = set(task.task_id for task in tasks) process_funcs = {task.task_id: task.process_function for task in tasks} @@ -63,3 +66,4 @@ def run_blockwise(self, tasks, scheduler=None) -> dict[str, TaskState]: if len(process_funcs) == 0: self.notify_server_exit() return scheduler.task_states + raise NotImplementedError("Unreachable") diff --git a/daisy/server.py b/daisy/server.py index 7eaf389b..de88e840 100644 --- a/daisy/server.py +++ b/daisy/server.py @@ -25,9 +25,7 @@ class Server(ServerObservee): - def __init__(self, stop_event=None): - super().__init__() if stop_event is None: @@ -41,7 +39,6 @@ def __init__(self, stop_event=None): logger.debug("Started server listening at %s:%s", self.hostname, self.port) def run_blockwise(self, tasks, scheduler=None) -> dict[str, TaskState]: - if scheduler is None: self.scheduler = Scheduler(tasks) else: @@ -49,11 +46,13 @@ def run_blockwise(self, tasks, scheduler=None) -> dict[str, TaskState]: self.worker_pools = TaskWorkerPools(tasks, self) self.block_bookkeeper = BlockBookkeeper() - self.started_tasks = set() - self.finished_tasks = set() + self.started_tasks: set[str] = set() + self.finished_tasks: set[str] = set() self.all_done = False - self.pending_requests = {task.task_id: Queue() for task in tasks} + self.pending_requests: dict[str, Queue] = { + task.task_id: Queue() for task in tasks + } self._recruit_workers() @@ -82,7 +81,6 @@ def _event_loop(self): last_time = current_time def _get_client_message(self): - try: message = self.tcp_server.get_message(timeout=0.1) except StreamClosedError: @@ -92,7 +90,6 @@ def _get_client_message(self): return message for task_id, requests in self.pending_requests.items(): - if self.pending_requests[task_id].empty(): continue @@ -100,14 +97,12 @@ def _get_client_message(self): return self.pending_requests[task_id].get() def _send_client_message(self, stream, message): - try: stream.send_message(message) except StreamClosedError: pass def _handle_client_messages(self): - message = self._get_client_message() if message is None: @@ -116,7 +111,6 @@ def _handle_client_messages(self): self._handle_client_message(message) def _handle_client_message(self, message): - if isinstance(message, AcquireBlock): self._handle_acquire_block(message) elif isinstance(message, ReleaseBlock): @@ -129,7 +123,6 @@ def _handle_client_message(self, message): self._check_all_tasks_completed() def _handle_acquire_block(self, message): - logger.debug("Received block request for task %s", message.task_id) task_state = self.scheduler.task_states[message.task_id] @@ -139,11 +132,9 @@ def _handle_acquire_block(self, message): block = self.scheduler.acquire_block(message.task_id) if block is None: - assert task_state.ready_count == 0 if task_state.pending_count == 0: - logger.debug( "No more pending blocks for task %s, terminating " "client", message.task_id, @@ -168,7 +159,6 @@ def _handle_acquire_block(self, message): self.pending_requests[message.task_id].put(message) else: - try: logger.debug("Sending block %s to client", block) self._send_client_message(message.stream, SendBlock(block)) @@ -178,7 +168,6 @@ def _handle_acquire_block(self, message): self.notify_acquire_block(message.task_id, task_state) def _handle_release_block(self, message): - logger.debug("Client releases block %s", message.block) self._safe_release_block(message.block, message.stream) @@ -232,9 +221,7 @@ def _safe_release_block(self, block, stream): ) def _handle_client_exception(self, message): - if isinstance(message, BlockFailed): - logger.error( "Block %s failed in worker %s with %s", message.block, @@ -252,7 +239,6 @@ def _handle_client_exception(self, message): raise message.exception def _recruit_workers(self): - ready_tasks = self.scheduler.get_ready_tasks() ready_tasks = {task.task_id: task for task in ready_tasks} @@ -273,12 +259,10 @@ def _recruit_workers(self): self.worker_pools.recruit_workers(ready_tasks) def _check_for_lost_blocks(self): - lost_blocks = self.block_bookkeeper.get_lost_blocks() # mark as failed and release the lost blocks for block in lost_blocks: - logger.error("Block %s was lost, returning it to scheduler", block) block.status = BlockStatus.FAILED self._release_block(block)