15 changes: 12 additions & 3 deletions docs/source/rtc.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ policy_cfg = PI0Config()
policy_cfg.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10, # How many steps to blend with previous chunk
max_guidance_weight=10.0, # How strongly to enforce consistency
prefix_attention_schedule=RTCAttentionSchedule.EXP, # Exponential blend
)

Expand Down Expand Up @@ -101,7 +100,10 @@ Typical values: 8-12 steps
RTCConfig(execution_horizon=10)
```

**`max_guidance_weight`**: How strongly to enforce consistency with the previous chunk. This is a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy. For 10-step flow matching (SmolVLA, Pi0, Pi0.5), 10.0 is an optimal value.
**`max_guidance_weight`**: How strongly to enforce consistency with the previous chunk. This is a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy.

This is a clipping parameter, not the actual guidance weight. You might modify the sentence to include something like "a clipping parameter on the computed guidance weight. Ensures stability."


If `max_guidance_weight` is not set, the number of flow matching steps is used as the maximum guidance weight.
For details, see https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html
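
The fallback described above can be sketched in a few lines. This is a minimal illustration, not the library's code: the helper names `resolve_max_guidance_weight` and `clip_guidance_weight` are hypothetical and exist only for this example.

```python
from typing import Optional


def resolve_max_guidance_weight(configured: Optional[float], num_flow_matching_steps: int) -> float:
    """Return the clipping bound for the guidance weight (hypothetical helper)."""
    if num_flow_matching_steps <= 0:
        raise ValueError("num_flow_matching_steps must be positive")
    if configured is None:
        # Fallback: use the number of flow matching steps as the bound.
        return float(num_flow_matching_steps)
    return float(configured)


def clip_guidance_weight(raw_weight: float, bound: float) -> float:
    """Clip a computed guidance weight to the bound (hypothetical helper)."""
    return min(raw_weight, bound)


# With 10 flow matching steps and no explicit value, the bound is 10.0.
bound = resolve_max_guidance_weight(None, num_flow_matching_steps=10)
clipped = clip_guidance_weight(12.3, bound)
```

Here `clipped` ends up at the bound because the raw weight exceeds it; an explicitly configured value would take precedence over the step count.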

**`prefix_attention_schedule`**: How to weight consistency across the overlap region.

Expand All @@ -112,6 +114,14 @@ RTCConfig(execution_horizon=10)

**`inference_delay`**: How many timesteps of inference latency your system has. This is passed to `predict_action_chunk()` rather than the config, since it may vary at runtime.

**`sigma_d`**: The variance of the prior distribution. This is a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy.

nit: sigma is the "standard deviation", not the "variance".


Typical values: 0.1-1.0

By default, `sigma_d` is set to 1.0, which matches the original RTC paper. You can tune it to your needs: reduce it to get more reactivity, or increase it to get more smoothness.

For details, see https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html
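
To make the role of `sigma_d` concrete, here is a pure-Python sketch of the guidance weight arithmetic from the `modeling_rtc.py` change in this PR, written for a single scalar `tau`. The function name `guidance_weight` and the explicit edge-case branches are assumptions of this sketch; the PR itself works on tensors and relies on `torch.nan_to_num` for the degenerate cases.

```python
def guidance_weight(tau: float, sigma_d: float = 1.0, max_guidance_weight: float = 10.0) -> float:
    """Scalar sketch of the clipped guidance weight (hypothetical helper)."""
    if sigma_d < 0:
        raise ValueError("sigma_d must be non-negative")
    if tau >= 1.0:
        # c = (1 - tau) / tau is 0 here, so no guidance is applied.
        return 0.0
    if tau <= 0.0 or sigma_d == 0.0:
        # The unclipped weight is infinite, so the bound applies.
        return max_guidance_weight

    one_minus_tau_sq = (1.0 - tau) ** 2
    prior_variance = sigma_d ** 2  # sigma_d is squared before use
    inv_r2 = (one_minus_tau_sq + tau ** 2 * prior_variance) / (one_minus_tau_sq * prior_variance)
    c = (1.0 - tau) / tau
    return min(c * inv_r2, max_guidance_weight)


# Same tau, two values of sigma_d, default bound of 10.0.
w_default = guidance_weight(0.5, sigma_d=1.0)
w_smaller = guidance_weight(0.5, sigma_d=0.5)
```

Comparing `w_default` and `w_smaller` shows how the value changes with `sigma_d` before clipping; the bound only takes effect when the raw product exceeds `max_guidance_weight`.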

## Testing RTC Offline

Before running on a real robot, test RTC with dataset samples to visualize how it works:
Expand All @@ -121,7 +131,6 @@ python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi0_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--rtc.max_guidance_weight=10.0 \
--device=cuda
```

Expand Down
115 changes: 104 additions & 11 deletions examples/rtc/eval_dataset.py
Expand Up @@ -17,10 +17,15 @@
"""
Evaluate Real-Time Chunking (RTC) performance on dataset samples.

This script takes two random samples from a dataset:
This script takes two samples from a dataset:
- Uses actions from the first sample as previous chunk
- Generates new actions for the second sample with and without RTC

Sampling modes:
- Random (default): Two independent random samples
- Correlated (--sample_correlation_shift): Second sample is shifted from first by N steps
to test temporal correlation and sigma effects

It compares action predictions with and without RTC on dataset samples,
measuring consistency and ground truth alignment.

Expand All @@ -31,17 +36,30 @@
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
--rtc.prefix_attention_schedule=EXP \
--random_chunks=true \
--seed=10

uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--rtc.max_guidance_weight=11 \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--device=mps \
--seed=10 \
--random_chunks=true \
--rtc.sigma_d=1

# Basic usage with pi0.5 policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--device=mps
--seed=10
--device=mps \
--seed=10 \
--sample_correlation_shift=10 \
--rtc.sigma_d=1.0 \
--rtc.full_trajectory_alignment=true

# Basic usage with pi0.5 policy with cuda device
uv run python examples/rtc/eval_dataset.py \
Expand All @@ -63,6 +81,16 @@
--rtc.execution_horizon=8 \
--device=cuda

# With sample correlation shift to test temporal correlation (sigma effect)
# Second sample is taken as first_sample_index + shift
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--device=mps \
--sample_correlation_shift=5 \
--seed=10

# With torch.compile for faster inference (PyTorch 2.0+)
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
uv run python examples/rtc/eval_dataset.py \
Expand Down Expand Up @@ -161,7 +189,6 @@ class RTCEvalConfig(HubMixin):
default_factory=lambda: RTCConfig(
enabled=True,
execution_horizon=20,
max_guidance_weight=10.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=True,
debug_maxlen=1000,
Expand Down Expand Up @@ -191,6 +218,11 @@ class RTCEvalConfig(HubMixin):
metadata={"help": "Inference delay for RTC"},
)

num_inference_steps: int | None = field(
default=None,
metadata={"help": "Number of flow matching inference steps. If None, uses policy default."},
)

# Torch compile configuration
use_torch_compile: bool = field(
default=False,
Expand All @@ -215,6 +247,22 @@ class RTCEvalConfig(HubMixin):
},
)

next_inference_after: int = field(
default=10,
metadata={
"help": "How many steps after the first sample the second sample is taken "
"(second sample index = first sample index + next_inference_after)."
},
)

random_chunks: bool = field(
default=False,
metadata={
"help": "If true, draw the second sample at a random index after the first one instead of using a fixed shift. "
"Useful for checking a larger difference between the previous action chunk and the newly generated chunk."
},
)

def __post_init__(self):
# Parse policy path
policy_path = parser.get_path_arg("policy")
Expand Down Expand Up @@ -303,6 +351,17 @@ def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
if self.cfg.policy.type == "pi05" or self.cfg.policy.type == "pi0":
config.compile_model = self.cfg.use_torch_compile

# Override number of flow matching steps if specified
if self.cfg.num_inference_steps is not None:
if self.cfg.policy.type == "smolvla":
config.num_steps = self.cfg.num_inference_steps
logging.info(f" Overriding num_steps for SmolVLA: {self.cfg.num_inference_steps}")
elif self.cfg.policy.type in ["pi0", "pi05"]:
config.num_inference_steps = self.cfg.num_inference_steps
logging.info(
f" Overriding num_inference_steps for {self.cfg.policy.type}: {self.cfg.num_inference_steps}"
)

policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path, config=config)
policy = policy.to(self.device)
policy.eval()
Expand All @@ -315,6 +374,8 @@ def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
prefix_attention_schedule=self.cfg.rtc.prefix_attention_schedule,
debug=rtc_debug,
debug_maxlen=self.cfg.rtc.debug_maxlen,
full_trajectory_alignment=self.cfg.rtc.full_trajectory_alignment,
sigma_d=self.cfg.rtc.sigma_d,
)
policy.config.rtc_config = rtc_config
policy.init_rtc_processor()
Expand Down Expand Up @@ -433,13 +494,45 @@ def run_evaluation(self):
logging.info("=" * 80)
logging.info("Starting RTC evaluation")
logging.info(f"Inference delay: {self.cfg.inference_delay}")
if self.cfg.num_inference_steps is not None:
logging.info(f"Number of flow matching steps: {self.cfg.num_inference_steps}")
else:
logging.info("Number of flow matching steps: Using policy default")
logging.info("=" * 80)

# Load two random samples from dataset
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
loader_iter = iter(data_loader)
first_sample = next(loader_iter)
second_sample = next(loader_iter)
# Correlated sampling: second sample is shifted from first
shift = self.cfg.next_inference_after
logging.info(f"Using correlated sampling: second sample shifted by {shift} from first sample")

# Get random first index
first_idx = random.randint(0, len(self.dataset) - 1)

# Calculate second index with shift, ensuring it's within bounds
second_idx = first_idx + shift

if self.cfg.random_chunks:
second_idx = random.randint(first_idx + 1, len(self.dataset) - 1)

if second_idx < 0 or second_idx >= len(self.dataset):
raise ValueError(
f"Second sample index {second_idx} is out of bounds [0, {len(self.dataset) - 1}]. "
f"First index: {first_idx}, shift: {shift}. "
f"Please use a smaller shift value or adjust the seed."
)

logging.info(f"First sample index: {first_idx}, Second sample index: {second_idx}")

# Get samples directly from dataset
first_sample = self.dataset[first_idx]
second_sample = self.dataset[second_idx]

# Add batch dimension (dataset returns unbatched samples)
first_sample = {
k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in first_sample.items()
}
second_sample = {
k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in second_sample.items()
}

preprocessed_first_sample = self.preprocessor(first_sample)
preprocessed_second_sample = self.preprocessor(second_sample)
Expand All @@ -461,7 +554,7 @@ def run_evaluation(self):
with torch.no_grad():
prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
preprocessed_first_sample,
)[:, :25, :].squeeze(0)
)[:, shift : shift + 25, :].squeeze(0)
logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")

# Destroy policy_prev_chunk to free memory for large models
Expand Down
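
The index selection in `run_evaluation` raises a `ValueError` when the shifted index falls outside the dataset, so some seeds fail. One possible alternative, sketched below under assumptions, is to draw the first index only from positions where the second index is guaranteed to be in bounds. The helper `pick_sample_indices` is hypothetical and not part of this PR; it also restricts `shift` to non-negative values, which the PR does not.

```python
import random


def pick_sample_indices(dataset_len, shift, random_chunks=False, rng=None):
    """Pick (first_idx, second_idx) with second_idx always in bounds (hypothetical helper)."""
    rng = rng if rng is not None else random.Random()
    if dataset_len < 2:
        raise ValueError("need at least two samples in the dataset")

    if random_chunks:
        # Leave room for at least one later index after first_idx.
        first_idx = rng.randrange(0, dataset_len - 1)
        second_idx = rng.randint(first_idx + 1, dataset_len - 1)
    else:
        if not 0 <= shift < dataset_len:
            raise ValueError(f"shift must be in [0, {dataset_len - 1}], got {shift}")
        # Only sample positions where first_idx + shift stays inside the dataset.
        first_idx = rng.randrange(0, dataset_len - shift)
        second_idx = first_idx + shift

    return first_idx, second_idx


# Correlated pair: the second sample is exactly `shift` steps after the first.
first_idx, second_idx = pick_sample_indices(100, shift=10, rng=random.Random(10))
```

Passing a seeded `random.Random` keeps the choice reproducible without touching the global random state.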
3 changes: 1 addition & 2 deletions examples/rtc/eval_with_real_robot.py
Expand Up @@ -142,8 +142,7 @@ class RTCDemoConfig(HubMixin):
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
execution_horizon=10,
max_guidance_weight=1.0,
execution_horizon=15,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
)
Expand Down
1 change: 1 addition & 0 deletions src/lerobot/policies/pi0/modeling_pi0.py
Expand Up @@ -842,6 +842,7 @@ def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
num_flow_matching_steps=num_steps,
)
else:
v_t = denoise_step_partial_call(x_t)
Expand Down
1 change: 1 addition & 0 deletions src/lerobot/policies/pi05/modeling_pi05.py
Expand Up @@ -814,6 +814,7 @@ def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
num_flow_matching_steps=num_steps,
)
else:
v_t = denoise_step_partial_call(x_t)
Expand Down
18 changes: 16 additions & 2 deletions src/lerobot/policies/rtc/configuration_rtc.py
Expand Up @@ -40,16 +40,30 @@ class RTCConfig:
# Core RTC settings
# Todo change to exp
prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
max_guidance_weight: float = 10.0

# This parameter is used to clip the guidance weight.
# In the original RTC it's a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy.
# The original paper used a value of 5.0. During implementation it was found that this parameter does not need tuning and can be replaced with the number of flow matching steps.
# See https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html
# The number of steps can be used as the clipping parameter without any hyperparameter tuning.
# If the user doesn't provide this parameter, then the number of flow matching steps is used as the max guidance weight.
max_guidance_weight: float | None = None
execution_horizon: int = 10

# Standard deviation of the prior distribution (it is squared to get the prior variance when computing the guidance weight).
# See https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html
# The value can be in the range [0, 1]. If it's 1.0, then the behavior is the same as the original RTC.
sigma_d: float = 1.0

full_trajectory_alignment: bool = False

# Debug settings
debug: bool = False
debug_maxlen: int = 100

def __post_init__(self):
"""Validate RTC configuration parameters."""
if self.max_guidance_weight <= 0:
if self.max_guidance_weight is not None and self.max_guidance_weight <= 0:
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
if self.debug_maxlen <= 0:
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
37 changes: 32 additions & 5 deletions src/lerobot/policies/rtc/modeling_rtc.py
Expand Up @@ -120,6 +120,7 @@ def denoise_step(
inference_delay,
time,
original_denoise_step_partial,
num_flow_matching_steps,
execution_horizon=None,
) -> Tensor:
"""RTC guidance wrapper around an existing denoiser.
Expand All @@ -138,6 +139,9 @@ def denoise_step(
broadcastable with ``x_t``.
original_denoise_step_partial (Callable[[Tensor], Tensor]): Callable that
computes the base denoised velocity given only ``x_t``.
num_flow_matching_steps (int): Number of flow matching inference steps (must be positive integer).
If ``max_guidance_weight`` is ``None``, will be used as the max guidance weight
(Alex Soare optimization).
execution_horizon (int | None): Horizon used to build prefix weights. If
``None``, defaults to ``self.rtc_config.execution_horizon``.

Expand All @@ -153,6 +157,10 @@ def denoise_step(
- Guidance correction is computed via autograd using ``x1_t = x_t + time * v_t`` and
``error = (prev_chunk_left_over - x1_t) * weights``.
- The final guidance weight is clamped by ``max_guidance_weight`` from the config.
- Alex Soare optimization: If ``max_guidance_weight`` is ``None``,
``max_guidance_weight`` is automatically set to ``num_flow_matching_steps``
without requiring hyperparameter tuning.
Reference: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html

Reference:
https://www.physicalintelligence.company/download/real_time_chunking.pdf
Expand Down Expand Up @@ -209,18 +217,37 @@ def denoise_step(
)

with torch.enable_grad():
v_t = original_denoise_step_partial(x_t)
x_t.requires_grad_(True)
v_t = original_denoise_step_partial(x_t)

x1_t = x_t - time * v_t # noqa: N806
err = (prev_chunk_left_over - x1_t) * weights
grad_outputs = err.clone().detach()
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]

max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
correction = err

# If full trajectory alignment is enabled (this is not the default RTC behavior),
# the newly generated trajectory is fully aligned with the previous chunk. This is similar to ignoring the gradients
# from the neural network and taking into account only the error between the previous chunk and the newly generated trajectory.
# It runs faster, and if the distance between chunk generations is not too large, it gives smoother transitions.
if not self.rtc_config.full_trajectory_alignment:
grad_outputs = err.clone().detach()
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]

# Alex Soare optimization: Use num_flow_matching_steps as max_guidance_weight if not set
# Reference: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html
# The number of flow matching steps can be used as a clipping parameter without hyperparameter tuning
max_guidance_weight = self.rtc_config.max_guidance_weight
if max_guidance_weight is None:
max_guidance_weight = num_flow_matching_steps

max_guidance_weight = torch.as_tensor(max_guidance_weight)

tau_tensor = torch.as_tensor(tau)
squared_one_minus_tau = (1 - tau_tensor) ** 2
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
prior_variance = torch.as_tensor(self.rtc_config.sigma_d**2)
inv_r2 = (squared_one_minus_tau + tau_tensor**2 * prior_variance) / (
squared_one_minus_tau * prior_variance
)
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
Expand Down
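
The difference between the two correction modes in this hunk can be illustrated with a toy example. The sketch below assumes an elementwise linear denoiser `v(x) = a * x`, chosen only because its Jacobian is known in closed form; the function `correction_toy` and the parameter `a` are hypothetical. In the PR, the non-aligned branch computes the same kind of vector-Jacobian product with `torch.autograd.grad` through the real network.

```python
def correction_toy(x_t, prev_chunk, weights, time, a, full_trajectory_alignment):
    """Toy correction for an elementwise linear denoiser v(x) = a * x (hypothetical)."""
    corrections = []
    for x, prev, w in zip(x_t, prev_chunk, weights):
        v = a * x
        x1 = x - time * v  # same form as x1_t in the hunk above
        err = (prev - x1) * w
        if full_trajectory_alignment:
            # Use the weighted error directly, skipping the network gradient.
            corrections.append(err)
        else:
            # Vector-Jacobian product: d(x1)/d(x) = 1 - time * a for this toy denoiser.
            corrections.append(err * (1.0 - time * a))
    return corrections


x_t = [1.0, 2.0]
prev_chunk = [1.0, 1.0]
weights = [1.0, 0.5]

aligned = correction_toy(x_t, prev_chunk, weights, time=0.5, a=0.5, full_trajectory_alignment=True)
through_grad = correction_toy(x_t, prev_chunk, weights, time=0.5, a=0.5, full_trajectory_alignment=False)
```

For this toy denoiser the two modes differ only by the constant factor `1 - time * a`; with a real network the Jacobian couples action dimensions and timesteps, so the difference is not a simple rescaling.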
1 change: 1 addition & 0 deletions src/lerobot/policies/smolvla/modeling_smolvla.py
Expand Up @@ -814,6 +814,7 @@ def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
num_flow_matching_steps=self.config.num_steps,
)
else:
v_t = denoise_step_partial_call(x_t)
Expand Down