Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions src/inspect_ai/_eval/evalset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import hashlib
import logging
import time
from dataclasses import dataclass
from typing import Any, Literal, NamedTuple, Set, cast

Expand All @@ -15,7 +16,7 @@
stop_after_attempt,
wait_exponential,
)
from typing_extensions import Unpack
from typing_extensions import Unpack, override

from inspect_ai._display import display as display_manager
from inspect_ai._eval.task.log import plan_to_eval_plan
Expand Down Expand Up @@ -78,6 +79,51 @@ class Log(NamedTuple):
task_identifier: str


class EvalSetScanProgress(ReadEvalLogsProgress):
    """Logs progress while scanning for existing eval logs during resume.

    Emits an INFO line when the scan starts, throttled progress lines while
    log headers are read, and (via `log_completion`) a final timing summary
    once the resume scan has finished.
    """

    def __init__(self, log_dir: str) -> None:
        # directory being scanned (included in the initial log message)
        self.log_dir = log_dir
        # total file count reported by before_reading_logs()
        self.total_files = 0
        # number of log files read so far
        self.completed = 0
        # monotonic timestamp of scan start (None until the scan begins)
        self.start_time: float | None = None
        # completed count at the last progress line (used for throttling)
        self._last_log_count = 0

    @override
    def before_reading_logs(self, total_files: int) -> None:
        self.total_files = total_files
        self.start_time = time.monotonic()
        # stay silent for an empty log directory
        if total_files > 0:
            logger.info(
                "Found %d eval log files in %s, reading headers...",
                total_files,
                self.log_dir,
            )

    @override
    def after_read_log(self, log_file: str) -> None:
        self.completed += 1
        # throttle: log at ~10% increments, but never more often than every
        # 100 files; the final file always triggers a line
        interval = max(100, self.total_files // 10)
        if (
            self.completed - self._last_log_count >= interval
            or self.completed == self.total_files
        ):
            logger.info("Reading eval logs: %d/%d", self.completed, self.total_files)
            self._last_log_count = self.completed

    def log_completion(self, completed_count: int, pending_count: int) -> None:
        """Log final summary after scan completes.

        Args:
            completed_count: Number of tasks found already completed.
            pending_count: Number of tasks still pending.
        """
        # no-op if the scan never started or found no files
        if self.start_time is not None and self.total_files > 0:
            duration = time.monotonic() - self.start_time
            logger.info(
                "Resume scan complete: %d logs read in %.1fs (%d completed, %d pending)",
                # report the number actually read, which may be less than
                # total_files if some reads failed mid-scan
                self.completed,
                duration,
                completed_count,
                pending_count,
            )


@dataclass
class EvalSetArgsInTaskIdentifier:
config: GenerateConfig
Expand Down Expand Up @@ -390,7 +436,8 @@ def try_eval() -> list[EvalLog]:
)

# list all logs currently in the log directory (update manifest if there are some)
all_logs = list_all_eval_logs(log_dir)
scan_progress = EvalSetScanProgress(log_dir)
all_logs = list_all_eval_logs(log_dir, progress=scan_progress)
if len(all_logs) > 0:
write_log_dir_manifest(log_dir)

Expand Down Expand Up @@ -435,6 +482,13 @@ def try_eval() -> list[EvalLog]:
pending_tasks = [
task[1] for task in all_tasks if task[0] not in log_task_identifiers
]

# Log scan completion summary
scan_progress.log_completion(
completed_count=len(log_task_identifiers),
pending_count=len(pending_tasks),
)

tasks_to_run: (
list[ResolvedTask | PreviousTask] | list[ResolvedTask] | list[PreviousTask]
)
Expand Down
86 changes: 86 additions & 0 deletions tests/_eval/test_evalset_progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import logging

import pytest

from inspect_ai._eval.evalset import EvalSetScanProgress


class TestEvalSetScanProgress:
    """Unit tests for the resume-scan progress logger."""

    def test_before_reading_logs_zero_files(self, caplog: pytest.LogCaptureFixture) -> None:
        """An empty directory must not produce any log output."""
        scan = EvalSetScanProgress("s3://bucket/logs")
        with caplog.at_level(logging.INFO):
            scan.before_reading_logs(0)
        assert not caplog.records

    def test_before_reading_logs_many_files(self, caplog: pytest.LogCaptureFixture) -> None:
        """A large scan logs the file count and the directory up front."""
        scan = EvalSetScanProgress("s3://bucket/logs")
        with caplog.at_level(logging.INFO):
            scan.before_reading_logs(500)
        assert "500 eval log files" in caplog.text
        assert "s3://bucket/logs" in caplog.text

    def test_progress_logging_interval(self, caplog: pytest.LogCaptureFixture) -> None:
        """Progress lines are throttled to intervals rather than per file."""
        scan = EvalSetScanProgress("s3://bucket/logs")
        scan.before_reading_logs(1000)

        with caplog.at_level(logging.INFO):
            caplog.clear()
            for idx in range(1000):
                scan.after_read_log(f"file_{idx}.eval")

        matching = [
            rec for rec in caplog.records if "Reading eval logs:" in rec.message
        ]
        assert len(matching) == 10

    def test_progress_logging_small_set(self, caplog: pytest.LogCaptureFixture) -> None:
        """A small scan still emits the final progress line."""
        scan = EvalSetScanProgress("s3://bucket/logs")
        scan.before_reading_logs(50)

        with caplog.at_level(logging.INFO):
            caplog.clear()
            for idx in range(50):
                scan.after_read_log(f"file_{idx}.eval")

        matching = [
            rec for rec in caplog.records if "Reading eval logs:" in rec.message
        ]
        assert len(matching) == 1
        assert "50/50" in matching[0].message

    def test_completion_summary(self, caplog: pytest.LogCaptureFixture) -> None:
        """The completion summary includes counts and timing."""
        scan = EvalSetScanProgress("s3://bucket/logs")
        scan.before_reading_logs(100)
        scan.completed = 100

        with caplog.at_level(logging.INFO):
            caplog.clear()
            scan.log_completion(completed_count=80, pending_count=20)

        assert "80 completed" in caplog.text
        assert "20 pending" in caplog.text
        assert "100 logs read" in caplog.text

    def test_completion_no_logs_when_zero_files(
        self, caplog: pytest.LogCaptureFixture
    ) -> None:
        """No summary is emitted when nothing was scanned."""
        scan = EvalSetScanProgress("s3://bucket/logs")
        scan.before_reading_logs(0)

        with caplog.at_level(logging.INFO):
            caplog.clear()
            scan.log_completion(completed_count=0, pending_count=5)

        assert not caplog.records

    def test_completion_no_logs_before_start(
        self, caplog: pytest.LogCaptureFixture
    ) -> None:
        """No summary is emitted if before_reading_logs was never called."""
        scan = EvalSetScanProgress("s3://bucket/logs")
        with caplog.at_level(logging.INFO):
            scan.log_completion(completed_count=80, pending_count=20)

        assert not caplog.records