diff --git a/src/inspect_ai/_eval/evalset.py b/src/inspect_ai/_eval/evalset.py
index b41faf5dc5..e83ddcc440 100644
--- a/src/inspect_ai/_eval/evalset.py
+++ b/src/inspect_ai/_eval/evalset.py
@@ -1,5 +1,6 @@
 import hashlib
 import logging
+import time
 from dataclasses import dataclass
 from typing import Any, Literal, NamedTuple, Set, cast
 
@@ -15,7 +16,7 @@
     stop_after_attempt,
     wait_exponential,
 )
-from typing_extensions import Unpack
+from typing_extensions import Unpack, override
 
 from inspect_ai._display import display as display_manager
 from inspect_ai._eval.task.log import plan_to_eval_plan
@@ -78,6 +79,51 @@ class Log(NamedTuple):
     task_identifier: str
 
 
+class EvalSetScanProgress(ReadEvalLogsProgress):
+    """Logs progress while scanning for existing eval logs during resume."""
+
+    def __init__(self, log_dir: str) -> None:
+        self.log_dir = log_dir
+        self.total_files = 0
+        self.completed = 0
+        self.start_time: float | None = None
+        self._last_log_count = 0
+
+    @override
+    def before_reading_logs(self, total_files: int) -> None:
+        self.total_files = total_files
+        self.start_time = time.monotonic()
+        if total_files > 0:
+            logger.info(
+                "Found %d eval log files in %s, reading headers...",
+                total_files,
+                self.log_dir,
+            )
+
+    @override
+    def after_read_log(self, log_file: str) -> None:
+        self.completed += 1
+        interval = max(100, self.total_files // 10)
+        if (
+            self.completed - self._last_log_count >= interval
+            or self.completed == self.total_files
+        ):
+            logger.info("Reading eval logs: %d/%d", self.completed, self.total_files)
+            self._last_log_count = self.completed
+
+    def log_completion(self, completed_count: int, pending_count: int) -> None:
+        """Log final summary after scan completes."""
+        if self.start_time is not None and self.total_files > 0:
+            duration = time.monotonic() - self.start_time
+            logger.info(
+                "Resume scan complete: %d logs read in %.1fs (%d completed, %d pending)",
+                self.total_files,
+                duration,
+                completed_count,
+                pending_count,
+            )
+
+
 @dataclass
 class EvalSetArgsInTaskIdentifier:
     config: GenerateConfig
@@ -390,7 +436,8 @@ def try_eval() -> list[EvalLog]:
         )
 
         # list all logs currently in the log directory (update manifest if there are some)
-        all_logs = list_all_eval_logs(log_dir)
+        scan_progress = EvalSetScanProgress(log_dir)
+        all_logs = list_all_eval_logs(log_dir, progress=scan_progress)
         if len(all_logs) > 0:
             write_log_dir_manifest(log_dir)
 
@@ -435,6 +482,13 @@ def try_eval() -> list[EvalLog]:
         pending_tasks = [
             task[1] for task in all_tasks if task[0] not in log_task_identifiers
         ]
+
+        # Log scan completion summary
+        scan_progress.log_completion(
+            completed_count=len(log_task_identifiers),
+            pending_count=len(pending_tasks),
+        )
+
         tasks_to_run: (
             list[ResolvedTask | PreviousTask] | list[ResolvedTask] | list[PreviousTask]
         )
diff --git a/tests/_eval/test_evalset_progress.py b/tests/_eval/test_evalset_progress.py
new file mode 100644
index 0000000000..afd55d75ec
--- /dev/null
+++ b/tests/_eval/test_evalset_progress.py
@@ -0,0 +1,86 @@
+import logging
+
+import pytest
+
+from inspect_ai._eval.evalset import EvalSetScanProgress
+
+
+class TestEvalSetScanProgress:
+    def test_before_reading_logs_zero_files(self, caplog: pytest.LogCaptureFixture) -> None:
+        """Zero files should produce no log output."""
+        progress = EvalSetScanProgress("s3://bucket/logs")
+        with caplog.at_level(logging.INFO):
+            progress.before_reading_logs(0)
+        assert len(caplog.records) == 0
+
+    def test_before_reading_logs_many_files(self, caplog: pytest.LogCaptureFixture) -> None:
+        """Many files should log initial count."""
+        progress = EvalSetScanProgress("s3://bucket/logs")
+        with caplog.at_level(logging.INFO):
+            progress.before_reading_logs(500)
+        assert "500 eval log files" in caplog.text
+        assert "s3://bucket/logs" in caplog.text
+
+    def test_progress_logging_interval(self, caplog: pytest.LogCaptureFixture) -> None:
+        """Progress should log at intervals, not every file."""
+        progress = EvalSetScanProgress("s3://bucket/logs")
+        progress.before_reading_logs(1000)
+
+        with caplog.at_level(logging.INFO):
+            caplog.clear()
+            for i in range(1000):
+                progress.after_read_log(f"file_{i}.eval")
+
+        progress_lines = [r for r in caplog.records if "Reading eval logs:" in r.message]
+        assert len(progress_lines) == 10
+
+    def test_progress_logging_small_set(self, caplog: pytest.LogCaptureFixture) -> None:
+        """Small file sets should still log at reasonable intervals."""
+        progress = EvalSetScanProgress("s3://bucket/logs")
+        progress.before_reading_logs(50)
+
+        with caplog.at_level(logging.INFO):
+            caplog.clear()
+            for i in range(50):
+                progress.after_read_log(f"file_{i}.eval")
+
+        progress_lines = [r for r in caplog.records if "Reading eval logs:" in r.message]
+        assert len(progress_lines) == 1
+        assert "50/50" in progress_lines[0].message
+
+    def test_completion_summary(self, caplog: pytest.LogCaptureFixture) -> None:
+        """Completion should log summary with timing."""
+        progress = EvalSetScanProgress("s3://bucket/logs")
+        progress.before_reading_logs(100)
+        progress.completed = 100
+
+        with caplog.at_level(logging.INFO):
+            caplog.clear()
+            progress.log_completion(completed_count=80, pending_count=20)
+
+        assert "80 completed" in caplog.text
+        assert "20 pending" in caplog.text
+        assert "100 logs read" in caplog.text
+
+    def test_completion_no_logs_when_zero_files(
+        self, caplog: pytest.LogCaptureFixture
+    ) -> None:
+        """Completion should not log if no files were scanned."""
+        progress = EvalSetScanProgress("s3://bucket/logs")
+        progress.before_reading_logs(0)
+
+        with caplog.at_level(logging.INFO):
+            caplog.clear()
+            progress.log_completion(completed_count=0, pending_count=5)
+
+        assert len(caplog.records) == 0
+
+    def test_completion_no_logs_before_start(
+        self, caplog: pytest.LogCaptureFixture
+    ) -> None:
+        """Completion should not log if before_reading_logs was never called."""
+        progress = EvalSetScanProgress("s3://bucket/logs")
+        with caplog.at_level(logging.INFO):
+            progress.log_completion(completed_count=80, pending_count=20)
+
+        assert len(caplog.records) == 0
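
Note (not part of the patch): for reviewers, a minimal sketch of the hook contract this change assumes. ReadEvalLogsProgress, before_reading_logs, and after_read_log are taken from the patch itself; the no-op defaults and the read_headers driver loop below are illustrative assumptions, not the actual list_all_eval_logs implementation.

    # Assumed base interface: no-op hooks that subclasses may override.
    class ReadEvalLogsProgress:
        def before_reading_logs(self, total_files: int) -> None:
            """Called once with the total file count before any headers are read."""

        def after_read_log(self, log_file: str) -> None:
            """Called after each individual log header has been read."""

    # Hypothetical driver loop (read_headers is an illustrative name):
    # announce the total up front, then report each file as it is read.
    def read_headers(log_files: list[str], progress: ReadEvalLogsProgress) -> None:
        progress.before_reading_logs(len(log_files))
        for log_file in log_files:
            ...  # read the eval log header for log_file here
            progress.after_read_log(log_file)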