Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tinker_cookbook/utils/ml_log.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Simplified logging utilities for tinker-cookbook."""

import atexit
import json
import logging
import os
Expand Down Expand Up @@ -205,6 +206,14 @@ def log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None:
self._store.write_metrics(metrics, step)
logger.info("Wrote metrics to %s/metrics.jsonl", self.log_dir)

def sync(self) -> None:
    """Push any writes buffered by the store out to its backing storage.

    Delegates to the store's ``flush()``. On cloud backends this is the
    step that uploads data which has only been staged so far.
    """
    store = self._store
    store.flush()

def close(self) -> None:
    """Flush whatever the store still has buffered ahead of shutdown.

    There is no separate teardown here: closing simply performs one
    final ``sync()``.
    """
    # Route through sync() so shutdown shares the regular flush path.
    self.sync()


class PrettyPrintLogger(Logger):
"""Logger that displays metrics as a Rich-formatted table in the console.
Expand Down Expand Up @@ -608,6 +617,10 @@ def setup_logging(
# Create multiplex logger
ml_logger = MultiplexLogger(loggers)

# Flush staged writes on interpreter exit so a non-clean shutdown (before
# ml_logger.close()) still uploads metrics/timing/checkpoints on cloud backends.
atexit.register(ml_logger.sync)

# Log initial configuration
if config is not None:
ml_logger.log_hparams(config)
Expand Down
36 changes: 35 additions & 1 deletion tinker_cookbook/utils/ml_log_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import sys
from unittest.mock import patch

from .ml_log import configure_logging_module
from tinker_cookbook.stores.storage import LocalStorage
from tinker_cookbook.stores.training_store import TrainingRunStore

from .ml_log import JsonLogger, configure_logging_module, setup_logging


def _flush_root_handlers() -> None:
Expand Down Expand Up @@ -40,3 +43,34 @@ def test_configure_logging_module_logs_invocation_and_appends(tmp_path):
assert final_contents.count("Command line invocation:") == 2
assert final_contents.index("first message") < final_contents.index(second_invocation)
assert final_contents.index(second_invocation) < final_contents.index("second message")


def test_json_logger_sync_and_close_flush_store(tmp_path):
    """Both JsonLogger.sync() and JsonLogger.close() must flush the store.

    Flushing is what lets staged cloud writes upload, so each call is
    expected to reach the storage's ``flush()`` exactly once.
    """

    class _FlushCountingStorage(LocalStorage):
        # Class-level default; the first increment creates a per-instance
        # counter that shadows it.
        flush_count = 0

        def flush(self) -> None:
            # Only count invocations; the real flush is deliberately not run.
            self.flush_count += 1

    counting_storage = _FlushCountingStorage(tmp_path)
    run_store = TrainingRunStore(counting_storage)
    json_logger = JsonLogger(tmp_path, store=run_store)

    json_logger.log_metrics({"loss": 1.0}, step=0)

    # One flush is expected from sync() and a second from close().
    json_logger.sync()
    json_logger.close()

    assert counting_storage.flush_count == 2


def test_setup_logging_registers_atexit_flush(tmp_path, monkeypatch):
    """setup_logging must hook the logger's sync() into atexit.

    The atexit hook is what still uploads staged data when the process
    exits without a clean close().
    """
    import atexit

    captured_callbacks = []

    def _record_registration(fn, *args, **kwargs):
        # Stand-in for atexit.register: remember the callback, register nothing.
        captured_callbacks.append(fn)

    monkeypatch.setattr(atexit, "register", _record_registration)

    ml_logger = setup_logging(str(tmp_path), do_configure_logging_module=False)

    # Bound methods compare equal when they wrap the same function and instance.
    assert ml_logger.sync in captured_callbacks
Loading