|
1 | | -import logging |
2 | 1 | import os |
3 | 2 | import pickle |
| 3 | +import random |
4 | 4 | from unittest import mock |
5 | 5 |
|
6 | 6 | import pytest |
@@ -218,3 +218,33 @@ def test_tiny_large_loop(): |
218 | 218 | got_result = scheduler.get_result(True, 1) |
219 | 219 |
|
220 | 220 | assert got_result == result_pkl |
| 221 | + |
| 222 | + |
| 223 | +@pytest.mark.local |
| 224 | +def test_larger_jobs_prioritized(): |
| 225 | + """Larger jobs should be scheduled first""" |
| 226 | + |
| 227 | + task_q, result_q = SpawnContext.Queue(), SpawnContext.Queue() |
| 228 | + scheduler = MPITaskScheduler(task_q, result_q) |
| 229 | + |
| 230 | + max_nodes = len(scheduler.available_nodes) |
| 231 | + |
| 232 | + # The first task will get scheduled with all the nodes, |
| 233 | + # and the remainder hits the backlog queue. |
| 234 | + node_request_list = [max_nodes] + [random.randint(1, 4) for _i in range(8)] |
| 235 | + |
| 236 | + for task_id, num_nodes in enumerate(node_request_list): |
| 237 | + mock_task_buffer = pack_res_spec_apply_message("func", "args", "kwargs", |
| 238 | + resource_specification={ |
| 239 | + "num_nodes": num_nodes, |
| 240 | + "ranks_per_node": 2 |
| 241 | + }) |
| 242 | + task_package = {"task_id": task_id, "buffer": mock_task_buffer} |
| 243 | + scheduler.put_task(task_package) |
| 244 | + |
| 245 | + # Confirm that the tasks are sorted in decreasing order |
| 246 | + prev_priority = 0 |
| 247 | + for i in range(len(node_request_list) - 1): |
| 248 | + p_task = scheduler._backlog_queue.get() |
| 249 | + assert p_task.priority < 0 |
| 250 | + assert p_task.priority <= prev_priority |
0 commit comments