diff --git a/alf/bin/train.py b/alf/bin/train.py
index ff9a6f58e..05fb444e8 100644
--- a/alf/bin/train.py
+++ b/alf/bin/train.py
@@ -118,16 +118,19 @@ def check_valid_launch():
         assert not extra, f"Unexpected environment variables for non-distributed launch: {extra}"
 
 
-def _setup_logging(rank: int, log_dir: str):
+def _setup_logging(log_dir: str):
     """Setup logging for each process
 
     Args:
-        rank (int): The ID of the process among all of the DDP processes
         log_dir (str): path to the directory where log files are written to
     """
     FLAGS.alsologtostderr = True
     logging.set_verbosity(logging.INFO)
     logging.get_absl_handler().use_absl_log_file(log_dir=log_dir)
+    # Spawned subprocesses create a new interpreter, so the default logging
+    # falls back to Python's logging module. For DDP worker logging to work,
+    # we need to explicitly set it back to absl.
+    logging.use_absl_handler()
 
 
 def _setup_device():
@@ -232,7 +235,7 @@ def training_worker(rank: int,
             in different worker processes, if multi-gpu training is used.
     """
     try:
-        _setup_logging(log_dir=root_dir, rank=rank)
+        _setup_logging(log_dir=root_dir)
         _setup_device()
         if world_size > 1:
             # Specialization for distributed mode
@@ -298,7 +301,7 @@ def training_worker_multi_node(local_rank: int,
             in different worker processes, if multi-gpu training is used.
     """
    try:
-        _setup_logging(log_dir=root_dir, rank=rank)
+        _setup_logging(log_dir=root_dir)
         _setup_device()
 
         # Specialization for distributed mode
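
For context, a minimal standalone sketch (not part of this patch; names and setup are assumed, not taken from the ALF codebase) of the behavior the new comment describes: when workers are started with torch.multiprocessing.spawn, each child runs in a fresh interpreter that does not go through the parent's app.run(), so absl's log handler is not attached there and the worker's log lines fall back to Python's default logging until logging.use_absl_handler() is called.

# Sketch only: illustrates why a spawned DDP worker needs
# logging.use_absl_handler(); assumed example, not ALF code.
import torch.multiprocessing as mp
from absl import logging


def _worker(rank: int):
    # In the real training setup the parent attaches absl's handler via
    # app.run(); a spawned child starts a fresh interpreter where that
    # handler is absent, so re-attach it before logging anything.
    logging.use_absl_handler()
    logging.set_verbosity(logging.INFO)
    logging.info("worker %d is logging through absl again", rank)


if __name__ == "__main__":
    # Spawn two workers; without use_absl_handler() in _worker, their log
    # lines would bypass any absl log file configured by the parent.
    mp.spawn(_worker, nprocs=2)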