141 changes: 123 additions & 18 deletions alf/algorithms/rl_algorithm.py
@@ -19,7 +19,7 @@
import os
import time
import torch
from typing import Callable, Optional
from typing import Callable, List, Optional, Tuple
from absl import logging

import alf
@@ -544,6 +544,7 @@ def _async_unroll(self, unroll_length: int):
store_exp_time = 0.
step_time = 0.
max_step_time = 0.
effective_unroll_steps = 0
Collaborator: I think we lack a formal definition of "effective" in the code document.

Contributor Author: Added more comments with examples, especially in preprocess_unroll_experience

qsize = self._async_unroller.get_queue_size()
unroll_results = self._async_unroller.gather_unroll_results(
unroll_length, self._config.max_unroll_length)
@@ -566,10 +567,12 @@ def _async_unroll(self, unroll_length: int):
step_time += unroll_result.step_time
max_step_time = max(max_step_time, unroll_result.step_time)

store_exp_time += self._process_unroll_step(
store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step(
policy_step, policy_step.output, time_step,
transformed_time_step, policy_state, experience_list,
original_reward_list)
store_exp_time += store_exp_time_i
effective_unroll_steps += effective_unroll_steps_i

alf.summary.scalar("time/unroll_env_step",
env_step_time,
@@ -596,20 +599,105 @@

self._current_transform_state = common.detach(trans_state)

return experience
# if the input unroll_length is 0 (e.g. fractional unroll), then this is treated as
# an effective unroll iter
effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
return experience, effective_unroll_iters

def preprocess_unroll_experience(
self, rollout_info, step_type: StepType,
experiences: Experience) -> Tuple[List, float]:
"""A function for processing the experience obtained from an unroll step before
being saved into the replay buffer. By default, it returns the input
experience unmodified. Users can customize this function in the derived
class to achieve different effects. For example:
- per-step processing: return the current step of experience unmodified (by default)
or a modified version according to the customized ``preprocess_unroll_experience``.
As another example, task filtering can be simply achieved by returning ``[]``
for that particular task.
- per-episode processing: this can be achieved by returning a list of processed
experiences. For example, this can be used for success episode labeling.

Args:
rollout_info: the rollout info.
step_type: the step type of the current experience.
experiences: one step of experience.

Returns:
- ``effective_experiences``: a list of experiences. Users can customize this
function in the derived class to achieve different effects. For example:
* return a list that contains only the input experience (default behavior).
* return a list that contains a number of experiences. This can be useful
for episode processing such as success episode labeling.
- ``effective_unroll_steps``: a value representing the effective number of
unroll steps per env. The default value is 1, meaning that the length of the
effective experience after calling ``preprocess_unroll_experience`` is 1,
the same as the input length of experience.
The value of ``effective_unroll_steps`` can be set differently according
to different scenarios, e.g.:
(1) per-step saving without delay: save each step of unroll experience
into the replay buffer as we get it. Set ``effective_unroll_steps``
to 1 so that each step will be counted as valid and there will be no
impact on the train/unroll ratio.
(2) all-step saving with delay: save all the steps of unroll experience into
the replay buffer with a delay. This can happen when we want to
annotate a trajectory based on some quantities that are not immediately
available in the current step (e.g. task success/failure). In this case,
we can simply cache the experiences and set ``effective_experiences=[]``
until the quantities required for annotation are obtained.
After obtaining them, we can set ``effective_experiences`` to the cached
and annotated experiences.
To maintain the original unroll/train iter ratio, we can set
``effective_unroll_steps=1``, meaning each unroll step is regarded as
effective in terms of the unroll/train iter ratio, even though the
pace of saving the unroll steps into the replay buffer has been altered.
(3) selective saving: exclude some of the unroll experiences and only save
the rest. This can be useful when there are transitions
that are irrelevant to the training (e.g. in the multi-task case, where
we want to exclude data from certain subtasks).
This can be achieved by setting ``effective_experiences=[]`` for the
steps to be excluded, and ``effective_experiences = [experiences]``
otherwise. If we do not want to trigger a train iter for an excluded
unroll step, we can set ``effective_unroll_steps=0``;
otherwise, we can set ``effective_unroll_steps=1``.
(4) parallel environments: in the case of parallel environments, the value
of ``effective_unroll_steps`` can be set according to the modes described
above and the status of each environment (e.g. ``effective_unroll_steps``
can be set to an average value across environments). Note that this could
result in a floating-point number.
"""
effective_experiences = [experiences]
effective_unroll_steps = 1
return effective_experiences, effective_unroll_steps
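As a purely illustrative companion to scenario (3) above (not part of this diff): a derived class could override the hook roughly as follows. The ``MyMultiTaskAlgorithm`` class and the ``task_id`` field on ``rollout_info`` are hypothetical, and the sketch assumes the base class is ``RLAlgorithm`` from this module and a single (non-parallel) environment.

from alf.algorithms.rl_algorithm import RLAlgorithm


class MyMultiTaskAlgorithm(RLAlgorithm):
    """Hypothetical subclass, used only to sketch the override."""

    def preprocess_unroll_experience(self, rollout_info, step_type, experiences):
        # Assume rollout_info carries a scalar ``task_id`` and that task 0
        # should be excluded from training (selective saving, scenario (3)).
        if int(rollout_info.task_id) == 0:
            # Drop this transition and do not count it toward the
            # unroll/train-iter ratio.
            return [], 0
        # Default behavior: keep the step and count it as one effective step.
        return [experiences], 1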

def _process_unroll_step(self, policy_step, action, time_step,
transformed_time_step, policy_state,
experience_list, original_reward_list):
experience_list,
original_reward_list) -> Tuple[float, float]:
"""

Returns:
- ``store_exp_time``: the time spent on storing the experience
- ``effective_unroll_steps``: a value representing the effective number
of unroll steps per env. The default value of 1, meaning the length of
effective experience is 1 after calling ``preprocess_unroll_experience``,
the same as the input length of experience. For more details on it,
please refer to the docstr of ``preprocess_unroll_experience``.
"""
self.observe_for_metrics(time_step.cpu())
exp = make_experience(time_step.cpu(),
alf.layers.to_float32(policy_step),
alf.layers.to_float32(policy_state))

effective_unroll_steps = 1
store_exp_time = 0
if not self.on_policy:
# 1) pre-process unroll experience
pre_processed_exp_list, effective_unroll_steps = self.preprocess_unroll_experience(
policy_step.info, time_step.step_type, exp)
# 2) observe
t0 = time.time()
self.observe_for_replay(exp)
for exp in pre_processed_exp_list:
self.observe_for_replay(exp)
store_exp_time = time.time() - t0

exp_for_training = Experience(
@@ -620,7 +708,7 @@ def _process_unroll_step(self, policy_step, action, time_step,

experience_list.append(exp_for_training)
original_reward_list.append(time_step.reward)
return store_exp_time
return store_exp_time, effective_unroll_steps

def reset_state(self):
"""Reset the state of the algorithm.
@@ -644,6 +732,8 @@ def _sync_unroll(self, unroll_length: int):
Returns:
Experience: The stacked experience with shape :math:`[T, B, \ldots]`
for each of its members.
effective_unroll_iters: the effective number of unroll iterations.
Each unroll iteration contains ``unroll_length`` unroll steps.
"""
if self._current_time_step is None:
self._current_time_step = common.get_initial_time_step(self._env)
@@ -665,6 +755,7 @@
policy_step_time = 0.
env_step_time = 0.
store_exp_time = 0.
effective_unroll_steps = 0
for _ in range(unroll_length):
policy_state = common.reset_state_if_necessary(
policy_state, initial_state, time_step.is_first())
@@ -693,9 +784,11 @@
if self._overwrite_policy_output:
policy_step = policy_step._replace(
output=next_time_step.prev_action)
store_exp_time += self._process_unroll_step(
store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step(
policy_step, action, time_step, transformed_time_step,
policy_state, experience_list, original_reward_list)
store_exp_time += store_exp_time_i
effective_unroll_steps += effective_unroll_steps_i

time_step = next_time_step
policy_state = policy_step.state
@@ -723,7 +816,12 @@
self._current_policy_state = common.detach(policy_state)
self._current_transform_state = common.detach(trans_state)

return experience
# if the input unroll_length is 0 (e.g. fractional unroll), then this is treated as
# an effective unroll iter.
# One ``effective_unroll_iter`` refers to calling ``rollout_step`` ``unroll_length``
# times in the unroll phase.
effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
Collaborator: It's strange to call unroll "iter"? The original definition is that each training iter we have one unroll. So what does unroll iters mean in this context?

Contributor Author: Added comments. One effective_unroll_iter refers to the unroll_length times of calling of rollout_step in the unroll phase.

return experience, effective_unroll_iters
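To make the bookkeeping above concrete, here is a small self-contained illustration of the formula; the numbers are arbitrary and not taken from this PR.

# Illustrative arithmetic only; the numbers are arbitrary.
unroll_length = 8

# Every step counted as effective by preprocess_unroll_experience:
effective_unroll_steps = 8
effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
assert effective_unroll_iters == 1  # one training iteration will be triggered

# Half of the steps excluded (their effective_unroll_steps_i == 0):
effective_unroll_steps = 4
effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
assert effective_unroll_iters == 0  # no training iteration for this unroll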

def train_iter(self):
"""Perform one iteration of training.
@@ -747,7 +845,7 @@ def _compute_train_info_and_loss_info_on_policy(self, unroll_length):
with record_time("time/unroll"):
with torch.cuda.amp.autocast(self._config.enable_amp,
dtype=self._config.amp_dtype):
experience = self.unroll(self._config.unroll_length)
experience, _ = self.unroll(self._config.unroll_length)
self.summarize_metrics()

train_info = experience.rollout_info
@@ -788,6 +886,9 @@ def _unroll_iter_off_policy(self):
unroll length, it may not have been called.
- root_inputs: root-level time step returned by the unroll
- rollout_info: rollout info returned by the unroll
- effective_unroll_iters: the effective number of unroll iterations.
``train_from_replay_buffer`` will be run ``effective_unroll_iters`` times
during ``_train_iter_off_policy``.
"""
config: TrainerConfig = self._config

@@ -804,6 +905,7 @@
unrolled = False
root_inputs = None
rollout_info = None
effective_unroll_iters = 0
if (alf.summary.get_global_counter()
>= self._rl_train_after_update_steps
and (unroll_length > 0 or config.unroll_length == 0) and
@@ -822,19 +924,21 @@
# need to remember whether summary has been written between
# two unrolls.
with self._ensure_rollout_summary:
experience = self.unroll(unroll_length)
experience, effective_unroll_iters = self.unroll(
unroll_length)
if experience:
self.summarize_rollout(experience)
self.summarize_metrics()
rollout_info = experience.rollout_info
if config.use_root_inputs_for_after_train_iter:
root_inputs = experience.time_step
del experience
return unrolled, root_inputs, rollout_info
return unrolled, root_inputs, rollout_info, effective_unroll_iters

def _train_iter_off_policy(self):
"""User may override this for their own training procedure."""
unrolled, root_inputs, rollout_info = self._unroll_iter_off_policy()
unrolled, root_inputs, rollout_info, effective_unroll_iters = self._unroll_iter_off_policy(
)

# replay buffer may not have been created for two different reasons:
# 1. in online RL training (``has_offline`` is False), unroll is not
@@ -846,11 +950,12 @@ def _train_iter_off_policy(self):
return 0

self.train()
steps = self.train_from_replay_buffer(update_global_counter=True)

if unrolled:
with record_time("time/after_train_iter"):
self.after_train_iter(root_inputs, rollout_info)
steps = 0
for i in range(effective_unroll_iters):
Contributor: it's possible the effective_unroll_iters is always smaller than 1 in the case of num_envs > 1.

Contributor Author (@Haichao-Zhang, May 23, 2025): Good point. Now also handles the fractional unroll case.

steps += self.train_from_replay_buffer(update_global_counter=True)
if unrolled:
with record_time("time/after_train_iter"):
self.after_train_iter(root_inputs, rollout_info)

# For now, we only return the steps of the primary algorithm's training
return steps
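The reviewer's point about ``effective_unroll_iters`` staying below 1 for ``num_envs > 1`` is, per the author's reply, handled by the fractional-unroll support; that handling is not visible in the lines shown here. One possible way to deal with a fractional value (a sketch under the assumption of a carry-over accumulator, not necessarily what this PR implements; ``_accumulate_unroll_iters`` and ``pending`` are hypothetical names) is:

def _accumulate_unroll_iters(pending: float, effective_unroll_iters: float):
    """Hypothetical helper: carry the fractional part of the effective unroll
    iters across training iterations, so that e.g. 0.25 iters per call still
    triggers one training pass every fourth call."""
    pending += effective_unroll_iters
    whole_iters = int(pending)
    return whole_iters, pending - whole_iters


# 0.25 effective iters per call -> one training pass on every fourth call.
pending = 0.0
passes = []
for _ in range(4):
    whole, pending = _accumulate_unroll_iters(pending, 0.25)
    passes.append(whole)
assert passes == [0, 0, 0, 1]

The returned ``whole_iters`` would then drive the ``for`` loop above in place of ``effective_unroll_iters``, with ``pending`` stored on the algorithm between calls.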