alf/algorithms/config.py (6 additions, 0 deletions)

@@ -53,6 +53,7 @@ def __init__(self,
                  num_eval_episodes=10,
                  num_eval_environments: int = 1,
                  async_eval: bool = True,
+                 shared_train_eval_env: bool = False,
                  save_checkpoint_for_best_eval: Optional[Callable] = None,
                  ddp_paras_check_interval: int = 0,
                  num_summaries=None,

@@ -227,6 +228,10 @@ def __init__(self,
             num_eval_environments: the number of environments for evaluation.
             async_eval: whether to do evaluation asynchronously in a different
                 process. Note that this may use more memory.
+            shared_train_eval_env: whether the training and evaluation environments are
+                shared. If True, the environment instance used in training will also be
+                used in evaluation. This is useful for cases such as RL-in-real with a
+                single physical environment.
             save_checkpoint_for_best_eval: If provided, will be called with a list of
                 evaluation metrics. If it returns True, a checkpoint will be saved.
                 A possible value of this option is `alf.trainers.evaluator.BestEvalChecker()`,

@@ -383,6 +388,7 @@ def __init__(self,
         self.num_eval_episodes = num_eval_episodes
         self.num_eval_environments = num_eval_environments
         self.async_eval = async_eval
+        self.shared_train_eval_env = shared_train_eval_env
         self.save_checkpoint_for_best_eval = save_checkpoint_for_best_eval
         self.ddp_paras_check_interval = ddp_paras_check_interval
         self.num_summaries = num_summaries
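
As an aside, here is a minimal sketch of how the new option might be enabled in
a training conf file, assuming ALF's usual `alf.config` mechanism; the
surrounding option values are illustrative, not part of this diff:

    import alf

    alf.config(
        'TrainerConfig',
        evaluate=True,
        num_eval_episodes=10,
        # Reuse the single training environment for evaluation, e.g. one
        # physical robot. The evaluator asserts that async evaluation is off
        # in this mode, so async_eval must be disabled explicitly.
        shared_train_eval_env=True,
        async_eval=False,
    )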

alf/trainers/evaluator.py (9 additions, 4 deletions)

@@ -74,10 +74,15 @@ def __init__(self, config: TrainerConfig, conf_file: str):
                 pre_configs, num_envs, config.root_dir, seed))
             self._worker.start()
         else:
-            self._env = create_environment(
-                for_evaluation=True,
-                num_parallel_environments=num_envs,
-                seed=seed)
+            if config.shared_train_eval_env:

    [Review thread]
    Contributor: Need to set the step_type of the replay buffer entry just
    before evaluation starts to StepType.LAST.
    emailweixu (Mar 24, 2025): Also need to set the next step type for
    training to FIRST.

+                assert not self._async, "should not use async_eval in shared_train_eval_env mode"
+                self._env = alf.get_env()
+                self._env.reset()

    [Review thread]
    Contributor: assert async_eval == False?
    Author: Good point. Added assertion.
    Contributor: Commit not pushed to the right remote?
    Author: Ah, yes, that was what happened ... now pushed.
    Contributor: Happened to me also. It's hard to remember, especially now
    that we don't change alf that often. We can probably remove the other
    remote.
    Contributor: Curious, why do we call env.reset() here but not in the other
    branch? Maybe add a comment in the code?


+            else:
+                self._env = create_environment(
+                    for_evaluation=True,
+                    num_parallel_environments=num_envs,
+                    seed=seed)
         self._evaluator = SyncEvaluator(self._env, config)

     def eval(self, algorithm: RLAlgorithm, step_metric_values: Dict[str, int]):
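
The step-type concern raised in the first review thread can be made concrete
with a small self-contained sketch. This is not ALF's API (ToyReplayBuffer,
ToyEnv, and SharedEnvRunner are invented names for illustration); it only
shows the bookkeeping the reviewers ask for: when evaluation borrows the
training environment, the last stored training step is closed off as LAST, and
the first step after evaluation resumes as FIRST.

    from enum import Enum
    from typing import List, Tuple


    class StepType(Enum):
        FIRST = 0
        MID = 1
        LAST = 2


    class ToyReplayBuffer:
        """Toy buffer storing (step_type, observation) pairs."""

        def __init__(self):
            self.steps: List[Tuple[StepType, int]] = []

        def add(self, step_type: StepType, obs: int):
            self.steps.append((step_type, obs))

        def close_current_episode(self):
            # Reviewer point 1: the step stored just before evaluation starts
            # must become LAST, since evaluation interrupts the episode.
            if self.steps and self.steps[-1][0] is not StepType.LAST:
                _, obs = self.steps[-1]
                self.steps[-1] = (StepType.LAST, obs)


    class ToyEnv:
        def __init__(self):
            self._t = 0

        def step(self) -> int:
            self._t += 1
            return self._t

        def reset(self):
            self._t = 0


    class SharedEnvRunner:
        """Toy runner where training and evaluation share one environment."""

        def __init__(self, env: ToyEnv, buffer: ToyReplayBuffer):
            self._env = env
            self._buffer = buffer
            self._resume_as_first = True  # the very first step is FIRST

        def train_step(self):
            obs = self._env.step()
            # Reviewer point 2: after an evaluation pause, re-open the episode
            # with FIRST so that pre- and post-evaluation steps are not
            # stitched into one trajectory.
            step_type = StepType.FIRST if self._resume_as_first else StepType.MID
            self._resume_as_first = False
            self._buffer.add(step_type, obs)

        def evaluate(self):
            self._buffer.close_current_episode()
            self._env.reset()  # evaluation episodes start from a clean state
            # ... run evaluation episodes on self._env here ...
            self._resume_as_first = True

This reading also suggests an answer to the last review question: the shared
training environment is generally mid-episode when evaluation begins, so it
needs an explicit reset, whereas `create_environment` presumably returns an
environment in a fresh state.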