diff --git a/.gitignore b/.gitignore
index b43881e8b..df5d95bcf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -97,5 +97,5 @@ website/pages/tutorials/*
 **/.ipynb_checkpoints/**

 # Configs for local development
-configs/config_local/*
+configs/config/config_local/*
 train_config.yaml
diff --git a/dev/__init__.py b/dev/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/dev/benchmark_suite/benchmark_suite_scheduler_defaults.json b/dev/benchmark_suite/benchmark_suite_scheduler_defaults.json
new file mode 100644
index 000000000..64ff04cc3
--- /dev/null
+++ b/dev/benchmark_suite/benchmark_suite_scheduler_defaults.json
@@ -0,0 +1,22 @@
+{
+  "params": {
+    "evaluation_iter_freq": -1,
+    "evaluation_phase_freq": -1,
+    "evaluate_final_phase": true,
+    "autoload_slurm_evaluator_checkpoint": false,
+    "slurm_evaluator_checkpoint": null,
+    "auto_retry_evaluations": false,
+    "retry_evaluation_job_ids": [],
+    "max_retries": 3,
+    "pytorch_ports": [40050]
+  },
+  "slurm_options": {
+    "NAME": "vissl",
+    "COMMENT": "vissl evaluation job",
+    "CONSTRAINT": "",
+    "TIMEOUT_MIN": 4320,
+    "CPUS_PER_TASK": 8,
+    "MEM_GB": 16,
+    "ADDITIONAL_PARAMETERS": {}
+  }
+}
diff --git a/dev/benchmark_suite/benchmark_suite_scheduler_template.json b/dev/benchmark_suite/benchmark_suite_scheduler_template.json
new file mode 100644
index 000000000..28f7c6f9c
--- /dev/null
+++ b/dev/benchmark_suite/benchmark_suite_scheduler_template.json
@@ -0,0 +1,33 @@
+{
+  "params": {
+    "training_checkpoint_dir": "(str) Training checkpoint directory. That is the CHECKPOINT.DIR of the training config",
+    "benchmarks": [
+      {
+        "evaluation_name": "(str) Name of benchmark for convenience",
+        "config_files": [
+          "config=path/to/evaluation/config",
+          "config.OVERRIDES=new_value"
+        ]
+      }
+    ],
+    "evaluation_iter_freq": "(int, default=-1) Evaluate the checkpoint every N iterations",
+    "evaluation_phase_freq": "(int, default=-1) Evaluate the checkpoint every N phases",
+    "evaluate_final_phase": "(bool, default=True) Evaluate the final phase",
+    "autoload_slurm_evaluator_checkpoint": "(bool, default=False) Whether or not to automatically load the benchmark checkpoint",
+    "slurm_evaluator_checkpoint": "(str, default=None) Path to load the benchmark checkpoint from",
+    "auto_retry_evaluations": "(bool, default=False) Whether or not to automatically retry the evaluations",
+    "retry_evaluation_job_ids": "(array[int], default=[]) Array of job_ids to retry",
+    "max_retries": "(int, default=3) Maximum number of retries",
+    "pytorch_ports": "(List[int], default=[40050]) List of PyTorch ports to cycle through as you launch your evaluations, in order to prevent PyTorch DDP port collisions."
+  },
+  "slurm_options": {
+    "PARTITION": "(str) Partition",
+    "NAME": "(str, default=vissl) Name of slurm job",
+    "COMMENT": "(str, default=vissl evaluation job) Comment of slurm job",
+    "CONSTRAINT": "(str, default='') Constraint of slurm job",
+    "TIMEOUT_MIN": "(int, default=72 * 60) Slurm job timeout in minutes",
+    "CPUS_PER_TASK": "(int, default=8) Number of CPUs per task.",
+    "MEM_GB": "(int, default=16) Amount of RAM in GB to request from slurm",
+    "ADDITIONAL_PARAMETERS": "(Dict[str, Any], default={}) Any additional slurm options to pass to submitit"
+  }
+}
diff --git a/dev/launch_benchmark_suite_scheduler_slurm.sh b/dev/launch_benchmark_suite_scheduler_slurm.sh
new file mode 100644
index 000000000..8e26703de
--- /dev/null
+++ b/dev/launch_benchmark_suite_scheduler_slurm.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This benchmark suite script launches a benchmark suite scheduler slurm job.
+# The job takes an absolute json config path (see benchmark_suite_scheduler_template.json for info).
+# The job continuously monitors the training checkpoints, dynamically launches evaluation jobs,
+# and amalgamates the results.
+
+######################### EXAMPLE USAGE #################################
+
+# cd into vissl root directory.
+#
+# bash ./dev/launch_benchmark_suite_scheduler_slurm.sh /path/to/benchmark_suite_scheduler.json
+
+# See benchmark_suite_scheduler_template.json for config information, or slurm_evaluator.py for class structure.
+######################### INPUT PARAMS ##################################
+
+FILE=( "$@" )
+
+####################### setup experiment dir ###################################
+
+# create a temporary experiment folder to run the SLURM job in isolation
+RUN_ID=$(date +'%Y-%m-%d-%H-%M-%S')
+EXP_ROOT_DIR="/checkpoint/$USER/vissl/$RUN_ID"
+
+echo "EXP_ROOT_DIR: $EXP_ROOT_DIR"
+echo "CONFIG_FILE: ${FILE[0]}"
+
+rm -rf "$EXP_ROOT_DIR"
+mkdir -p "$EXP_ROOT_DIR"
+cp -r . "$EXP_ROOT_DIR"
+
+####################### setup experiment dir ###################################
+export PYTHONPATH="$EXP_ROOT_DIR/:$PYTHONPATH"
+python -u "$EXP_ROOT_DIR/tools/launch_benchmark_suite_scheduler_slurm.py" \
+    "${FILE[@]}"
diff --git a/tools/launch_benchmark_suite_scheduler_slurm.py b/tools/launch_benchmark_suite_scheduler_slurm.py
new file mode 100644
index 000000000..8fcc0574d
--- /dev/null
+++ b/tools/launch_benchmark_suite_scheduler_slurm.py
@@ -0,0 +1,107 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+
+import pkg_resources
+import submitit
+from fvcore.common.file_io import PathManager
+from vissl.config.attr_dict import AttrDict
+from vissl.utils.benchmark_suite_scheduler import BenchmarkSuiteScheduler
+from vissl.utils.hydra_config import is_hydra_available
+from vissl.utils.io import load_file
+from vissl.utils.misc import recursive_dict_merge
+from vissl.utils.slurm import is_submitit_available
+
+
+# Default config options
+default_config_file = pkg_resources.resource_filename(
+    "dev", "benchmark_suite/benchmark_suite_scheduler_defaults.json"
+)
+_DEFAULT_CONFIG = load_file(default_config_file)
+
+
+class SlurmEvaluatorJob:
+    """
+    The slurm evaluator job is a thin wrapper around BenchmarkSuiteScheduler
+    used by submitit. Its main function is to run multiple evaluations
+    on a single training.
+    """
+
+    def __init__(self, benchmark_suite_scheduler: BenchmarkSuiteScheduler):
+        self.benchmark_suite_scheduler = benchmark_suite_scheduler
+
+    def __call__(self):
+        self.benchmark_suite_scheduler.evaluate()
+
+    def checkpoint(self):
+        """
+        This method is called whenever a job is preempted, timed out, etc.
+        Here we save the evaluation benchmarks, so that we can reload them
+        and continue where we left off.
+        """
+        self.benchmark_suite_scheduler.save_evaluation_benchmarks()
+        # Forces the benchmark_suite_scheduler to automatically reload its
+        # checkpoint (the benchmark results).
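+        # Submitit calls checkpoint() on preemption/timeout and requeues the
+        # DelayedSubmission returned below, so the fresh SlurmEvaluatorJob
+        # resumes from the benchmark results saved above.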
+ self.benchmark_suite_scheduler.autoload_benchmark_suite_scheduler_checkpoint = ( + True + ) + + trainer = SlurmEvaluatorJob( + benchmark_suite_scheduler=self.benchmark_suite_scheduler + ) + return submitit.helpers.DelayedSubmission(trainer) + + +def launch_benchmark_suite_scheduler(config_file): + assert PathManager.exists(config_file), "Slurm evaluator config file must exist" + + user_config = load_file(config_file) + config = _DEFAULT_CONFIG.copy() + recursive_dict_merge(config, user_config) + + benchmark_suite_scheduler = BenchmarkSuiteScheduler(**config["params"]) + benchmark_suite_scheduler_job = SlurmEvaluatorJob( + benchmark_suite_scheduler=benchmark_suite_scheduler + ) + executor = submitit.AutoExecutor(folder=benchmark_suite_scheduler.evaluation_dir()) + + assert "slurm_options" in config, "slurm_options must be specified" + assert ( + "PARTITION" in config["slurm_options"] + ), "slurm_options.PARTITION is a required field to launch the benchmark suite on slurm" + + slurm_options = AttrDict(config["slurm_options"]) + executor.update_parameters( + name=slurm_options.NAME, + slurm_comment=slurm_options.COMMENT, + slurm_partition=slurm_options.PARTITION, + slurm_constraint=slurm_options.CONSTRAINT, + timeout_min=slurm_options.TIMEOUT_MIN, + nodes=1, + cpus_per_task=slurm_options.CPUS_PER_TASK, + tasks_per_node=1, + mem_gb=slurm_options.MEM_GB, + slurm_additional_parameters=slurm_options.ADDITIONAL_PARAMETERS, + ) + + job = executor.submit(benchmark_suite_scheduler_job) + print(f"SUBMITTED EVALUATION JOB: {job.job_id}") + + +if __name__ == "__main__": + """ + Example usage: + python -u "./vissl/engines/benchmark_suite_scheduler.py" \ + "/path/to/benchmark_suite_scheduler_example.json" + """ + assert is_hydra_available(), "Make sure to install hydra" + + assert ( + is_submitit_available() + ), "Please 'pip install submitit' to schedule jobs on SLURM" + + config_file = sys.argv[1] + launch_benchmark_suite_scheduler(config_file) diff --git a/vissl/hooks/log_hooks.py b/vissl/hooks/log_hooks.py index ffecf7968..2fa181901 100644 --- a/vissl/hooks/log_hooks.py +++ b/vissl/hooks/log_hooks.py @@ -245,6 +245,11 @@ def on_update(self, task: "tasks.ClassyTask") -> None: "eta": eta_string, "peak_mem(M)": peak_mem_used, } + + if iteration == 1: + # Set max iterations. Currently used in benchmark_suite_scheduler.py + log_data["max_iterations"] = task.max_iteration + if self.btime_freq and len(batch_times) >= self.btime_freq: rolling_avg_time = ( sum(batch_times[-self.btime_freq :]) / self.btime_freq diff --git a/vissl/utils/benchmark_suite_scheduler.py b/vissl/utils/benchmark_suite_scheduler.py new file mode 100644 index 000000000..2b8bbcdb0 --- /dev/null +++ b/vissl/utils/benchmark_suite_scheduler.py @@ -0,0 +1,618 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
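# ---------------------------------------------------------------------------
# Editor's sketch (not part of this patch): a minimal user config for the
# scheduler, assuming a hypothetical checkpoint path and slurm partition.
# Only "training_checkpoint_dir", "benchmarks", and "slurm_options.PARTITION"
# are required; everything else is filled in from
# dev/benchmark_suite/benchmark_suite_scheduler_defaults.json via
# recursive_dict_merge in tools/launch_benchmark_suite_scheduler_slurm.py.
import json

example_config = {
    "params": {
        "training_checkpoint_dir": "/checkpoint/me/vissl/my_pretraining",  # hypothetical path
        "benchmarks": [
            {
                "evaluation_name": "in1k_linear",
                "config_files": [
                    # Evaluation config and overrides, as in the class docstring below.
                    "config=test/integration_test/quick_eval_in1k_linear.yaml",
                    "config.TRAIN.DATA_LIMIT=1000",
                ],
            }
        ],
        "evaluate_final_phase": True,
        "evaluation_phase_freq": -1,
        "evaluation_iter_freq": -1,
    },
    "slurm_options": {"PARTITION": "learnfair"},  # hypothetical partition name
}

with open("benchmark_suite_scheduler.json", "w") as f:
    json.dump(example_config, f, indent=2)

# Launch from the vissl root with:
#   bash ./dev/launch_benchmark_suite_scheduler_slurm.sh \
#       "$PWD/benchmark_suite_scheduler.json"
# ---------------------------------------------------------------------------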
+ +import collections +import json +import logging +import os +import time +from datetime import datetime +from pathlib import Path +from typing import List + +import submitit +from fvcore.common.file_io import PathManager +from hydra.experimental import compose, initialize_config_module +from vissl.config.attr_dict import AttrDict +from vissl.utils.distributed_launcher import launch_distributed_on_slurm +from vissl.utils.hydra_config import convert_to_attrdict +from vissl.utils.io import load_file, makedir +from vissl.utils.misc import flatten_dict, retry + + +""" +This class is designed to be used to run multiple evaluations on a single (pre)training. +Using the #evaluate method we continuously monitor training checkpoints, launch evaluations +dynamically as they become available, and amalgamate the evaluation results as they become +available. + +For SLURM usage, you should create a JSON configuration file +(see benchmark_suite_scheduler_template.json) and use +launch_benchmark_suite_scheduler_slurm.sh for convenience. +""" + +_DEFAULT_PYTORCH_PORTS = [40050] +# How many times to retry a slurm job submission. +_NUM_SLURM_RETRIES = 5 +# How many seconds to sleep between iterations of the main loop. +_SLEEP_TIME_SECONDS = 15 +# Slurm states marked as terminal. SlurmEvulator#evaluate will finish +# once all jobs are in a terminal state. +_SLURM_JOB_TERMINAL_STATES = [ + "BOOT_FAIL", + "CANCELLED", + "COMPLETED", + "DEADLINE", + "FAILED", + "NODE_FAIL", + "OUT_OF_MEMORY", + "REVOKED", + "SPECIAL_EXIT", + "STOPPED", + "SUSPENDED", + "TIMEOUT", +] +# Wait for the training checkpoint folder to be available for 1 hour. +_TRAINING_CONFIG_WAIT_SECONDS = 60 * 60 + + +class BenchmarkSuiteScheduler: + """ + The Slurm Evaluator is a class designed to continuously monitor VISSL pretrainings + and launch evaluations as checkpoints become available. The method takes a + config dictionary consisting of the training checkpoint directory, an array of + benchmarks, and information on how often to evaluate the trainings. + """ + + def __init__( + self, + training_checkpoint_dir: str, + benchmarks: List, + evaluate_final_phase=True, + evaluation_phase_freq: int = -1, + evaluation_iter_freq: int = -1, + autoload_slurm_evaluator_checkpoint=False, + slurm_evaluator_checkpoint: str = None, + retry_evaluation_job_ids: List[int] = None, + auto_retry_evaluations=False, + max_retries=3, + pytorch_ports=None, + ): + """ + Args: + + training_checkpoint_dir: (str). Checkpoint directory of the training. + This should match the trainings CHECKPOINT.dir + benchmarks: (list[dict]) Benchmarks with the following structure: + "config_files": [ + { + # Path to config file. + "config=test/integration_test/quick_eval_in1k_linear.yaml" + # Config overrides. + "config.TRAIN.DATA_LIMIT=1000", + ... + }, + ... + ] + evaluate_final_phase: (bool, optional, default=True). Whether or not to evaluate the + final phase of the training. + evaluation_phase_freq: (int, optional, default=-1) How often to evaluate phases. + Training checkpoint phase freq must evenly + divide evaluation_phase_freq. + evaluation_iter_freq: (int, optional, default=-1) How often to evaluate iterations. + Training checkpoint iteration freq must evenly + divide evaluation_iter_freq. + autoload_slurm_evaluator_checkpoint: (bool, optional, default=False) Whether or not to + autoload slurm_evaluator Checkpoint. + This is useful when slurm evaluator job + is preempted for example. 
+ slurm_evaluator_checkpoint: (str, optional, default=None) String of + slurm_evaluator checkpoint directory. + retry_evaluation_job_ids: (List[int], optional, default=[]) List of job_ids to retry. + auto_retry_evaluations: (bool, optional, default=False) Whether or not to automatically + retry all failed jobs. + max_retries: (int, optional, default=3). Maximum number of retries. + pytorch_ports: (list[int], optional, default=[40050]). Ports to cycle through as + you are launching your trainings. + """ + self.evaluation_jobs_finished = set() + + # Required Arguments + self.training_checkpoint_dir = training_checkpoint_dir + self.training_checkpoint_file = os.path.join( + self.training_checkpoint_dir, "train_config.yaml" + ) + self.benchmarks = benchmarks + + # Optional Arguments + self.evaluate_final_phase = evaluate_final_phase + self.evaluation_phase_freq = evaluation_phase_freq + self.evaluation_iter_freq = evaluation_iter_freq + self.autoload_slurm_evaluator_checkpoint = autoload_slurm_evaluator_checkpoint + self.slurm_evaluator_checkpoint = slurm_evaluator_checkpoint + self.retry_evaluation_job_ids = retry_evaluation_job_ids or [] + self.auto_retry_evaluations = auto_retry_evaluations + self.max_retries = max_retries + self.pytorch_ports = pytorch_ports or _DEFAULT_PYTORCH_PORTS + self.pytorch_ports_iterable = iter(self.pytorch_ports) + + self.validate() + + # Will be set in #evaluate, once training_checkpoint_dir becomes available. + self.training_config = None + self.evaluation_results = None + + def evaluate(self): + """ + Evaluate the checkpoints. At a high level, this is the structure. + + 1. Load training YAML config file. + 2. Monitor training checkpoints. When checkpoint is ready, launch the evaluation jobs. + 3. Monitor the evaluation jobs. When evaluation jobs are complete, mark the results. + 3. Once all jobs have been recorded, complete. + """ + start_time = time.time() + + # Wait for the training config to be available. This indicates the training has begun. + while True: + if time.time() - start_time > _TRAINING_CONFIG_WAIT_SECONDS: + raise RuntimeError( + f"Training config still doesn't exist after:" + f"{_TRAINING_CONFIG_WAIT_SECONDS / 60} minutes" + ) + + if ( + PathManager.exists(self.training_checkpoint_file) + and self._max_training_iterations() + ): + # Load training yaml config. + self._load_training_config() + + # Set max training iterations + self.max_training_iterations = self._max_training_iterations() + + # Generate evaluation results + self.evaluation_results = self._generate_initial_benchmark_results() + self._validate_evaluation_setup() + + break + + time.sleep(_SLEEP_TIME_SECONDS) + + # Save initial evaluation benchmarks, for checkpointing reasons. + self.save_evaluation_benchmarks() + + # Checkpoint folder is now available. Continuously monitor the training checkpoints, + # launch evaluation jobs as needed, monitor their progress, and record their results. + while True: + self._evaluate_checkpoints() + + self._check_evaluation_jobs() + + # Break if no more checkpoints to evaluate + if self._finished(): + logging.info("Evaluations are finished") + break + + time.sleep(_SLEEP_TIME_SECONDS) + + def _max_training_iterations(self): + """ + Get the max number of training iterations for the main SSL training. + """ + training_stdout_json_file = os.path.join( + self.training_checkpoint_dir, "stdout.json" + ) + + # If the stdout.json path doesn't exist, return None. 
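+        # stdout.json is produced by the training run's logging; as of this change,
+        # LogLossLrEtaHook also records "max_iterations" in its first logged
+        # iteration (see the vissl/hooks/log_hooks.py hunk above).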
+        if not PathManager.exists(training_stdout_json_file):
+            return None
+
+        with PathManager.open(training_stdout_json_file, "rb") as f:
+            # The first line of stdout.json must contain max_iterations.
+            try:
+                first_json_line = json.loads(next(f))
+                assert (
+                    "max_iterations" in first_json_line
+                ), "Training must set max_iterations in stdout.json. See LogLossLrEtaHook."
+                return first_json_line["max_iterations"]
+            except StopIteration:
+                return None
+
+    def save_evaluation_benchmarks(self):
+        """
+        Create the /evaluations directory inside the training checkpoint dir.
+        Write the evaluation metrics json to the parent evaluation directory,
+        as well as to each child evaluation directory.
+        """
+        # Write all checkpoints' evaluations to the parent evaluation directory.
+        evaluation_dir = self.evaluation_dir()
+        parent_metrics_file = os.path.join(evaluation_dir, "evaluation_metrics.json")
+
+        makedir(evaluation_dir)
+
+        self._write_json_file(self.evaluation_results, parent_metrics_file)
+
+        # Write each checkpoint's evaluations to its child directory.
+        for checkpoint_str, benchmarks in self.evaluation_results.items():
+            child_metrics_dir = os.path.join(evaluation_dir, checkpoint_str)
+            child_metrics_file = os.path.join(
+                child_metrics_dir, "evaluation_metrics.json"
+            )
+
+            makedir(child_metrics_dir)
+
+            self._write_json_file(benchmarks, child_metrics_file)
+
+        logging.info("Saved benchmarks json file.")
+
+    def evaluation_dir(self):
+        return os.path.join(self.training_checkpoint_dir, "evaluations")
+
+    def _load_training_config(self):
+        # Load training yaml config.
+        self.training_config = load_file(self.training_checkpoint_file)
+        self.training_config = AttrDict(self.training_config)
+
+        logging.info(
+            f"Loaded training checkpoint config from: { self.training_checkpoint_file }"
+        )
+
+    def validate(self):
+        """
+        Validate that the scheduler is configured consistently.
+        """
+        assert not (
+            self.autoload_slurm_evaluator_checkpoint and self.slurm_evaluator_checkpoint
+        ), "Specify only one of autoload_slurm_evaluator_checkpoint and slurm_evaluator_checkpoint"  # NOQA
+        assert (
+            type(self.evaluation_iter_freq) is int and self.evaluation_iter_freq >= -1
+        ), "The evaluation_iter_freq must be an int >= -1"
+        assert (
+            type(self.evaluation_phase_freq) is int and self.evaluation_phase_freq >= -1
+        ), "The evaluation_phase_freq must be an int >= -1"
+        assert (
+            self.evaluation_iter_freq >= -1
+            or self.evaluation_phase_freq >= -1
+            or self.evaluate_final_phase
+        ), "Please specify evaluation_iter_freq, evaluation_phase_freq, or evaluate_final_phase"  # NOQA
+        assert (
+            type(self.max_retries) is int and self.max_retries >= -1
+        ), "Max retries must be >= -1."
+
+    def _validate_evaluation_setup(self):
+        if self.evaluation_iter_freq > -1:
+            assert (
+                self.evaluation_iter_freq
+                % self.training_config.CHECKPOINT.CHECKPOINT_ITER_FREQUENCY
+            ) == 0, "The checkpoint iter frequency must evenly divide the evaluation iter frequency"  # NOQA
+
+        if self.evaluation_phase_freq > -1:
+            assert (
+                self.evaluation_phase_freq
+                % self.training_config.CHECKPOINT.CHECKPOINT_FREQUENCY
+            ) == 0, "The checkpoint phase frequency must evenly divide the evaluation phase frequency"  # NOQA
+
+        assert PathManager.exists(
+            self.training_config.SLURM.LOG_FOLDER
+        ), "Training slurm log folder must exist"
+        assert PathManager.exists(
+            self.training_config.CHECKPOINT.DIR
+        ), "Training slurm checkpoint folder must exist"
+
+    def _finished(self):
+        # Count total number of evaluation jobs.
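+        # "Finished" means a job reached a terminal slurm state, not necessarily
+        # that it succeeded, so the evaluate() loop exits even if some evaluations
+        # failed and were not (or could no longer be) retried.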
+ total_jobs = 0 + for benchmarks in self.evaluation_results.values(): + total_jobs += len(benchmarks) + + return len(self.evaluation_jobs_finished) == total_jobs + + def _evaluate_checkpoints(self): + for checkpoint_str, benchmarks in self.evaluation_results.items(): + # TODO: Can we possible retrieve this from CheckpointWriter, to consolidate logic. + checkpoint_str = os.path.join( + self.training_config.CHECKPOINT.DIR, f"{ checkpoint_str }.torch" + ) + if PathManager.exists(checkpoint_str): + self._evaluate_checkpoint(checkpoint_str, benchmarks) + + def _evaluate_checkpoint(self, checkpoint_str, benchmarks): + for benchmark in benchmarks: + retry_job = self._retry_job(benchmark) + if benchmark["job_id"] and not retry_job: + continue + + if retry_job: + self.evaluation_jobs_finished.remove(benchmark["job_id"]) + # Log the job retry. + job_id, slurm_state = benchmark["job_id"], benchmark["slurm_state"] + logging.info(f"Retrying job: { job_id } in state: { slurm_state }") + + args, config = self._generate_config(benchmark["config_files"]) + job = self._launch_slurm_job(args, config) + + time.sleep(10) # Wait for slurm job status to be reliably updated. + + # Set checkpoint status + benchmark["job_id"] = job.job_id + benchmark["num_retries"] += 1 + benchmark["slurm_log_dir"] = config.SLURM.LOG_FOLDER + benchmark["slurm_checkpoint_dir"] = config.CHECKPOINT.DIR + benchmark[ + "weights_init_params_file" + ] = config.MODEL.WEIGHTS_INIT.PARAMS_FILE + benchmark["slurm_state"] = job.state + + current_time = datetime.now().strftime("%H:%M:%S %z") + log = f""" + Launched Slurm Evaluation job. Time: { current_time } + job_id: { job.job_id }, num_retries: { benchmark["num_retries"] } + evaluation_name: { benchmark["evaluation_name"] } + checkpoint_str: { checkpoint_str } + state_prev: None, state_current: { job.state } + """ + + logging.info(log) + + # Save evaluation results to json file. + self.save_evaluation_benchmarks() + + def _retry_job(self, benchmark): + return benchmark["job_id"] in self.retry_evaluation_job_ids or ( + benchmark["slurm_state"] in _SLURM_JOB_TERMINAL_STATES + and benchmark["slurm_state"] != "COMPLETED" + and self.auto_retry_evaluations + and benchmark["num_retries"] < self.max_retries + ) + + @retry(n_tries=_NUM_SLURM_RETRIES) + def _launch_slurm_job(self, args, config): + # Get next port in the list of #pytorch_ports + try: + port = next(self.pytorch_ports_iterable) + except StopIteration: + # Start at the beginning of the ports list. + self.pytorch_ports_iterable = iter(self.pytorch_ports) + port = next(self.pytorch_ports_iterable) + + config.SLURM.PORT_ID = port + + return launch_distributed_on_slurm(engine_name=args.engine_name, cfg=config) + + def _write_json_file(self, data, file_name): + with PathManager.open(file_name, "w") as fopen: + fopen.write(json.dumps(data, sort_keys=True)) + fopen.flush() + + def _check_evaluation_jobs(self): + # Monitor each evaluation job, change slurm job state as needed, and + # load results if finished. + for benchmarks in self.evaluation_results.values(): + for benchmark in benchmarks: + self._monitor_benchmark_job(benchmark) + + def _monitor_benchmark_job(self, benchmark): + if not benchmark["job_id"]: + return # Do nothing, the job has not yet started. + + # Create SlurmJob object. + job_id = str(benchmark["job_id"]) + folder = Path(benchmark["slurm_log_dir"]) + job = submitit.SlurmJob(job_id=job_id, folder=folder, tasks=[0]) + + if job.state in _SLURM_JOB_TERMINAL_STATES: + # Job is in terminal state, mark job as finished. 
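+            # Failed or preempted jobs land here as well; the retry path in
+            # _evaluate_checkpoint may later remove them from this set and
+            # relaunch them, bounded by max_retries.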
+ self.evaluation_jobs_finished.add(job.job_id) + + if job.state != benchmark["slurm_state"]: + # Job state has changed, log transition, and update state in json file. + checkpoint_str = os.path.split(benchmark["weights_init_params_file"])[-1] + + current_time = datetime.now().strftime("%H:%M:%S %z") + log = f""" + Slurm Evaluation job changed states. Time: { current_time } + job_id: { job.job_id }, num_retries: { benchmark["num_retries"] } + evaluation_name: { benchmark["evaluation_name"] }, + checkpoint_str: { checkpoint_str }, + state_prev: { benchmark["slurm_state"] }, state_curr: { job.state } + """ + + logging.info(log) + # Benchmark Job state has changed. Update the benchmark state. + self._update_benchmark_state(benchmark, job) + self.save_evaluation_benchmarks() + + def _update_benchmark_state(self, benchmark, job): + # Job state has changed, record it. + benchmark["slurm_state"] = job.state + + if job.done(): + # Upload metrics files. + benchmark["metrics"] = self._get_benchmark_metrics(benchmark) + + def _get_benchmark_metrics(self, benchmark): + metrics_file = os.path.join(benchmark["slurm_checkpoint_dir"], "metrics.json") + + if PathManager.exists(metrics_file): + # Open metrics file from finished evaluation job. + metrics = [] + with PathManager.open(metrics_file, "rb") as f: + for line in f: + metrics.append(json.loads(line)) + + final_metrics = collections.defaultdict(lambda: {"metric": -1}) + + self._set_largest_metric(metrics, final_metrics) + + result = dict(final_metrics) + else: + result = """Evaluation Job has completed, but metrics.json is not available. + Please check the evaluation's checkpoint_dir.""" + + return result + + def _set_largest_metric(self, metrics, final_metrics): + # Get the largest metrics over all recorded metrics. + for m in metrics: + flattened_metrics = flatten_dict(m) + for metric_name, metric in flattened_metrics.items(): + if metric_name in ["iteration", "phase_idx", "train_phase_idx"]: + continue # These are not evaluation metrics + + if metric > final_metrics[metric_name]["metric"]: + final_metrics[metric_name]["metric"] = metric + final_metrics[metric_name]["iteration"] = flattened_metrics[ + "iteration" + ] + final_metrics[metric_name]["train_phase_idx"] = flattened_metrics[ + "train_phase_idx" + ] + + def _generate_initial_benchmark_results(self): + default_checkpoint = os.path.join( + self.evaluation_dir(), "evaluation_metrics.json" + ) + autoload_slurm_evaluator_checkpoint = ( + self.autoload_slurm_evaluator_checkpoint + and PathManager.exists(default_checkpoint) + ) + + if autoload_slurm_evaluator_checkpoint or self.slurm_evaluator_checkpoint: + return self._load_evaluation_results_checkpoint() + + evaluation_configs = {} + + for benchmark in self.benchmarks: + default_evaluation_name = os.path.split(benchmark["config_files"][0])[-1] + evaluation_name = ( + benchmark.get("evaluation_name") or default_evaluation_name + ) + + last_phase = self.training_config.OPTIMIZER.num_epochs - 1 + + # TODO: Can we retrieve this from CheckpointWriter? + if self.evaluate_final_phase: + # Evaluate Last phase checkpoint + training_checkpoint = f"model_final_checkpoint_phase{ last_phase }" + self._set_initial_benchmark_result( + benchmark, training_checkpoint, evaluation_name, evaluation_configs + ) + + if self.evaluation_phase_freq > -1: + # Evaluate every "evaluation_phase_freq" phase checkpoint. 
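+                # e.g. with evaluation_phase_freq=5 and last_phase=20 this yields
+                # phases 5, 10 and 15; the final phase itself is covered by the
+                # evaluate_final_phase branch above.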
+ evaluate_epochs = range(self.evaluation_phase_freq, last_phase)[ + :: self.evaluation_phase_freq + ] + for epoch in evaluate_epochs: + training_checkpoint = f"model_phase{epoch}" + self._set_initial_benchmark_result( + benchmark, + training_checkpoint, + evaluation_name, + evaluation_configs, + ) + + if self.evaluation_iter_freq > -1: + # Evaluate every "evaluation_iter_freq" iteration checkpoints. + evaluate_iterations = range( + self.evaluation_iter_freq, self.max_training_iterations + )[:: self.evaluation_iter_freq] + for iteration in evaluate_iterations: + training_checkpoint = f"model_iteration{iteration}" + self._set_initial_benchmark_result( + benchmark, + training_checkpoint, + evaluation_name, + evaluation_configs, + ) + + return evaluation_configs + + def _load_evaluation_results_checkpoint(self): + default_checkpoint = os.path.join( + self.evaluation_dir(), "evaluation_metrics.json" + ) + checkpoint_file = ( + default_checkpoint + if self.autoload_slurm_evaluator_checkpoint + else self.slurm_evaluator_checkpoint + ) + + evaluation_config = load_file(checkpoint_file) + + logging.info(f"Loaded evaluation results checkpoint from: { checkpoint_file }") + + return evaluation_config + + def _set_initial_benchmark_result( + self, benchmark, training_checkpoint, evaluation_name, evaluation_configs + ): + """ + Generates evaluation configs in order to evaluate the final output of the + pretraining model specified in 'training_config'. + """ + log_dir = self._evaluation_log_dir(training_checkpoint, evaluation_name) + checkpoint_dir = self._evaluation_checkpoint_dir( + training_checkpoint, evaluation_name + ) + + evaluation_configs[training_checkpoint] = ( + evaluation_configs.get(training_checkpoint) or [] + ) + + # Add benchmark result information + benchmark_result = { + "evaluation_name": evaluation_name, + "job_id": None, + "num_retries": 0, + "slurm_log_dir": log_dir, + "checkpoint_dir": checkpoint_dir, + "metrics": None, + "slurm_state": None, + "config_files": benchmark["config_files"].copy(), + } + + # Hydra config information + weights_init_path = os.path.join( + self.training_config.CHECKPOINT.DIR, f"{ training_checkpoint }.torch" + ) + + # Override certain options for slurm + for option in [ + f"config.MODEL.WEIGHTS_INIT.PARAMS_FILE='{weights_init_path}'", + "config.SLURM.USE_SLURM=true", + f"config.SLURM.LOG_FOLDER='{log_dir}'", + f"config.CHECKPOINT.DIR='{checkpoint_dir}'", + f"hydra.run.dir='{ log_dir }'", + ]: + benchmark_result["config_files"].insert(1, option) + + evaluation_configs[training_checkpoint].append(benchmark_result) + + def _evaluation_log_dir(self, evaluation_directory, evaluation_name): + """ + Directory to put logs for an evaluation job. + """ + evaluation_dir = self.evaluation_dir() + return os.path.join(evaluation_dir, evaluation_directory, evaluation_name) + + def _evaluation_checkpoint_dir(self, model_final_checkpoint, evaluation_name): + """ + Directory to put checkpoints in for an evaluation job. + """ + return os.path.join( + self._evaluation_log_dir(model_final_checkpoint, evaluation_name), + "checkpoints", + ) + + def _generate_config(self, config): + """ + Generate AttrDict config from a config YAML file and overrides. 
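+
+        For example (illustrative), the incoming `config` list for one checkpoint
+        typically looks like:
+            ["config=path/to/evaluation/config",
+             "config.MODEL.WEIGHTS_INIT.PARAMS_FILE='.../model_final_checkpoint_phaseN.torch'",
+             "config.SLURM.USE_SLURM=true", ...]
+        which hydra composes on top of the "defaults" config.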
+ """ + with initialize_config_module(config_module="vissl.config"): + config = compose("defaults", overrides=config) + + return convert_to_attrdict(config) diff --git a/vissl/utils/distributed_launcher.py b/vissl/utils/distributed_launcher.py index 04ac1f0d8..82811fdaf 100644 --- a/vissl/utils/distributed_launcher.py +++ b/vissl/utils/distributed_launcher.py @@ -275,3 +275,5 @@ def launch_distributed_on_slurm(cfg: AttrDict, engine_name: str): trainer = _ResumableSlurmJob(engine_name=engine_name, config=cfg) job = executor.submit(trainer) print(f"SUBMITTED: {job.job_id}") + + return job diff --git a/vissl/utils/io.py b/vissl/utils/io.py index 6a7271561..982db7c63 100644 --- a/vissl/utils/io.py +++ b/vissl/utils/io.py @@ -113,7 +113,10 @@ def load_file(filename, mmap_mode=None): data = np.load(fopen, encoding="latin1") elif file_ext == ".json": with PathManager.open(filename, "r") as fopen: - data = json.loads(fopen) + data = json.load(fopen) + elif file_ext == ".yaml": + with PathManager.open(filename, "r") as fopen: + data = yaml.load(fopen, Loader=yaml.FullLoader) else: raise Exception(f"Reading from {file_ext} is not supported yet") return data diff --git a/vissl/utils/misc.py b/vissl/utils/misc.py index ed91d1544..7447fb1f0 100644 --- a/vissl/utils/misc.py +++ b/vissl/utils/misc.py @@ -3,10 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import collections import logging import os import random import tempfile +import time +from functools import partial, wraps import numpy as np import pkg_resources @@ -308,3 +311,113 @@ def __enter__(self): def __exit__(self, *exc): set_rng_state(self.rng_state) + + +# Credit: https://stackoverflow.com/questions/42521549/retry-function-in-python +def retry(func=None, exception=Exception, n_tries=5, delay=5, backoff=1, logger=False): + """Retry decorator with exponential backoff. + + Parameters + ---------- + func : typing.Callable, optional + Callable on which the decorator is applied, by default None + exception : Exception or tuple of Exceptions, optional + Exception(s) that invoke retry, by default Exception + n_tries : int, optional + Number of tries before giving up, by default 5 + delay : int, optional + Initial delay between retries in seconds, by default 5 + backoff : int, optional + Backoff multiplier e.g. value of 2 will double the delay, by default 1 + logger : bool, optional + Option to log or print, by default False + + Returns + ------- + typing.Callable + Decorated callable that calls itself when exception(s) occur. + + Examples + -------- + >>> import random + >>> @retry(exception=Exception, n_tries=4) + ... def test_random(text): + ... x = random.random() + ... if x < 0.5: + ... raise Exception("Fail") + ... else: + ... print("Success: ", text) + >>> test_random("It works!") + """ + + if func is None: + return partial( + retry, + exception=exception, + n_tries=n_tries, + delay=delay, + backoff=backoff, + logger=logger, + ) + + @wraps(func) + def wrapper(*args, **kwargs): + ntries, ndelay = n_tries, delay + + while ntries > 1: + try: + return func(*args, **kwargs) + except exception as e: + msg = f"{str(e)}, Retrying in {ndelay} seconds..." 
+ if logger: + logging.warning(msg) + else: + print(msg) + time.sleep(ndelay) + ntries -= 1 + ndelay *= backoff + + return func(*args, **kwargs) + + return wrapper + + +# Credit: https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys # NOQA +def flatten_dict(d: dict, parent_key="", sep="_"): + """ + Flattens a dict, delimited with a '_'. For example the input: + { + 'top_1': { + 'res_5': 100 + } + } + + will return: + + { + 'top_1_res_5': 100 + } + """ + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, collections.MutableMapping): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +# Credit: https://stackoverflow.com/questions/7204805/how-to-merge-dictionaries-of-dictionaries +def recursive_dict_merge(dict1, dict2): + """ + Recursively merges dict2 into dict1 + """ + if not isinstance(dict1, dict) or not isinstance(dict2, dict): + return dict2 + for k in dict2: + if k in dict1: + dict1[k] = recursive_dict_merge(dict1[k], dict2[k]) + else: + dict1[k] = dict2[k] + return dict1
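A quick, self-contained sketch (editor's note, not part of the patch) of how the two helpers above behave; the launcher uses recursive_dict_merge in exactly this way to overlay a user config onto benchmark_suite_scheduler_defaults.json, and _set_largest_metric relies on flatten_dict to compare nested metric dicts. The "learnfair" partition below is a hypothetical value.

    from vissl.utils.misc import flatten_dict, recursive_dict_merge

    defaults = {"params": {"max_retries": 3, "pytorch_ports": [40050]}}
    user = {"params": {"max_retries": 5}, "slurm_options": {"PARTITION": "learnfair"}}

    # dict2 (the user config) wins on conflicting leaves; missing keys keep defaults.
    merged = recursive_dict_merge(defaults, user)
    assert merged == {
        "params": {"max_retries": 5, "pytorch_ports": [40050]},
        "slurm_options": {"PARTITION": "learnfair"},
    }

    # Nested metric dicts are flattened with "_" before comparing values.
    assert flatten_dict({"top_1": {"res_5": 100}}) == {"top_1_res_5": 100}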