44import pickle
55import queue
66import subprocess
7+ from dataclasses import dataclass , field
78from enum import Enum
89from typing import Dict , List , Optional
910
@@ -69,6 +70,14 @@ def __str__(self):
6970 return f"MPINodesUnavailable(requested={ self .requested } available={ self .available } )"
7071
7172
73+ @dataclass (order = True )
74+ class PrioritizedTask :
75+ # Comparing dict will fail since they are unhashable
76+ # This dataclass limits comparison to the priority field
77+ priority : int
78+ task : Dict = field (compare = False )
79+
80+
7281class TaskScheduler :
7382 """Default TaskScheduler that does no taskscheduling
7483
@@ -111,7 +120,7 @@ def __init__(
111120 super ().__init__ (pending_task_q , pending_result_q )
112121 self .scheduler = identify_scheduler ()
113122 # PriorityQueue is threadsafe
114- self ._backlog_queue : queue .PriorityQueue = queue .PriorityQueue ()
123+ self ._backlog_queue : queue .PriorityQueue [ PrioritizedTask ] = queue .PriorityQueue ()
115124 self ._map_tasks_to_nodes : Dict [str , List [str ]] = {}
116125 self .available_nodes = get_nodes_in_batchjob (self .scheduler )
117126 self ._free_node_counter = SpawnContext .Value ("i" , len (self .available_nodes ))
@@ -169,7 +178,8 @@ def put_task(self, task_package: dict):
169178 allocated_nodes = self ._get_nodes (nodes_needed )
170179 except MPINodesUnavailable :
171180 logger .info (f"Not enough resources, placing task { tid } into backlog" )
172- self ._backlog_queue .put ((nodes_needed , task_package ))
181+ # Negate the priority element so that larger tasks are prioritized
182+ self ._backlog_queue .put (PrioritizedTask (- 1 * nodes_needed , task_package ))
173183 return
174184 else :
175185 resource_spec ["MPI_NODELIST" ] = "," .join (allocated_nodes )
@@ -182,14 +192,16 @@ def put_task(self, task_package: dict):
182192
183193 def _schedule_backlog_tasks (self ):
184194 """Attempt to schedule backlogged tasks"""
185- try :
186- _nodes_requested , task_package = self ._backlog_queue .get (block = False )
187- self .put_task (task_package )
188- except queue .Empty :
189- return
190- else :
191- # Keep attempting to schedule tasks till we are out of resources
192- self ._schedule_backlog_tasks ()
195+
196+ # Separate fetching tasks from the _backlog_queue and scheduling them
197+ # since tasks that failed to schedule will be pushed to the _backlog_queue
198+ backlogged_tasks = []
199+ while not self ._backlog_queue .empty ():
200+ prioritized_task = self ._backlog_queue .get (block = False )
201+ backlogged_tasks .append (prioritized_task .task )
202+
203+ for backlogged_task in backlogged_tasks :
204+ self .put_task (backlogged_task )
193205
194206 def get_result (self , block : bool = True , timeout : Optional [float ] = None ):
195207 """Return result and relinquish provisioned nodes"""
0 commit comments