diff --git a/alf/algorithms/config.py b/alf/algorithms/config.py
index 2b8f94e4e..9f2a3dc37 100644
--- a/alf/algorithms/config.py
+++ b/alf/algorithms/config.py
@@ -53,6 +53,7 @@ def __init__(self,
                  num_eval_episodes=10,
                  num_eval_environments: int = 1,
                  async_eval: bool = True,
+                 shared_train_eval_env: bool = False,
                  save_checkpoint_for_best_eval: Optional[Callable] = None,
                  ddp_paras_check_interval: int = 0,
                  num_summaries=None,
@@ -227,6 +228,10 @@ def __init__(self,
             num_eval_environments: the number of environments for evaluation.
             async_eval: whether to do evaluation asynchronously in a different
                 process. Note that this may use more memory.
+            shared_train_eval_env: whether the training and evaluation environments are
+                shared. If True, the environment instance used in training will also be
+                used in evaluation. This is useful for cases such as rl-in-real with a
+                single physical environment.
             save_checkpoint_for_best_eval: If provided, will be called with a list of
                 evaluation metrics. If it returns True, a checkpoint will be saved. A
                 possible value of this option is `alf.trainers.evaluator.BestEvalChecker()`,
@@ -383,6 +388,7 @@ def __init__(self,
         self.num_eval_episodes = num_eval_episodes
         self.num_eval_environments = num_eval_environments
         self.async_eval = async_eval
+        self.shared_train_eval_env = shared_train_eval_env
         self.save_checkpoint_for_best_eval = save_checkpoint_for_best_eval
         self.ddp_paras_check_interval = ddp_paras_check_interval
         self.num_summaries = num_summaries
diff --git a/alf/trainers/evaluator.py b/alf/trainers/evaluator.py
index dd6367900..897b7f7e1 100644
--- a/alf/trainers/evaluator.py
+++ b/alf/trainers/evaluator.py
@@ -74,10 +74,15 @@ def __init__(self, config: TrainerConfig, conf_file: str):
                 pre_configs, num_envs, config.root_dir, seed))
             self._worker.start()
         else:
-            self._env = create_environment(
-                for_evaluation=True,
-                num_parallel_environments=num_envs,
-                seed=seed)
+            if config.shared_train_eval_env:
+                assert not self._async, "should not use async_eval in shared_train_eval_env mode"
+                self._env = alf.get_env()
+                self._env.reset()
+            else:
+                self._env = create_environment(
+                    for_evaluation=True,
+                    num_parallel_environments=num_envs,
+                    seed=seed)
         self._evaluator = SyncEvaluator(self._env, config)
 
     def eval(self, algorithm: RLAlgorithm, step_metric_values: Dict[str, int]):
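
Usage note (not part of the patch): a minimal sketch of how the new option might be enabled in a conf file, assuming the usual `alf.config('TrainerConfig', ...)` configuration mechanism. Since the evaluator asserts that async evaluation is off in `shared_train_eval_env` mode, the sketch disables `async_eval` explicitly.

    # Hypothetical conf-file snippet illustrating the new option.
    import alf

    alf.config(
        'TrainerConfig',
        shared_train_eval_env=True,   # reuse the training env instance for evaluation
        async_eval=False,             # required: shared mode cannot evaluate in a separate process
        num_eval_environments=1)      # a single (e.g. physical) environment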