Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions amp_rsl_rl/runners/amp_on_policy_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
delta_t,
self.cfg["slow_down_factor"],
amp_joint_names,
6.0,
)

# self.env.unwrapped.scene["robot"].joint_names)
Expand Down
98 changes: 92 additions & 6 deletions amp_rsl_rl/utils/motion_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path
from typing import List, Union, Tuple, Generator
from typing import List, Union, Tuple, Generator, Optional
from dataclasses import dataclass

import torch
import numpy as np
from scipy.spatial.transform import Rotation, Slerp
from scipy.interpolate import interp1d
from scipy.signal import butter, filtfilt # ← added for acausal low-pass filtering


def download_amp_dataset_from_hf(
Expand Down Expand Up @@ -88,8 +89,15 @@ class MotionData:

def __post_init__(self) -> None:
# Convert numpy arrays (or SciPy Rotations) to torch tensors

def to_tensor(x):
return torch.tensor(x, device=self.device, dtype=torch.float32)
# Ensure positive strides / contiguous memory before converting to torch
if not isinstance(x, np.ndarray):
x = np.array(x, dtype=np.float32)
# Make contiguous and float32 without copying if already OK
x = np.ascontiguousarray(x, dtype=np.float32)
return torch.from_numpy(x).to(self.device)


if isinstance(self.joint_positions, np.ndarray):
self.joint_positions = to_tensor(self.joint_positions)
Expand Down Expand Up @@ -185,6 +193,8 @@ class AMPLoader:
simulation_dt: Timestep used by the simulator
slow_down_factor: Integer factor to slow down original data
expected_joint_names: (Optional) override for joint ordering
vel_filter_cutoff_hz: (Optional) if set, apply a 2nd-order acausal (zero-phase)
low-pass Butterworth filter with this cutoff (Hz) to joint and base velocities.
"""

def __init__(
Expand All @@ -196,8 +206,10 @@ def __init__(
simulation_dt: float,
slow_down_factor: int,
expected_joint_names: Union[List[str], None] = None,
vel_filter_cutoff_hz: Optional[float] = None,
) -> None:
self.device = device
self.vel_filter_cutoff_hz = vel_filter_cutoff_hz
if isinstance(dataset_path_root, str):
dataset_path_root = Path(dataset_path_root)

Expand Down Expand Up @@ -302,10 +314,64 @@ def _compute_ang_vel(

return np.vstack((rotvec, rotvec[-1]))

def _compute_raw_derivative(self, data: np.ndarray, dt: float) -> np.ndarray:
d = (data[1:] - data[:-1]) / dt
def _compute_raw_derivative(
self, data: np.ndarray, dt: float, angular: bool = False
) -> np.ndarray:
"""
Finite-difference derivative with optional angle-wrap handling.

If `angular` is True, the difference is computed as the minimal wrapped
angle between consecutive samples using atan2(sin Δ, cos Δ). This
properly handles jumps across ±π.
"""
if data.shape[0] < 2:
return np.zeros_like(data)

if angular:
# Minimal angular difference in [-π, π] per element
delta = np.arctan2(
np.sin(data[1:] - data[:-1]),
np.cos(data[1:] - data[:-1]),
)
else:
delta = data[1:] - data[:-1]

d = delta / dt
return np.vstack([d, d[-1:]])

def _lowpass_acausal(
self, data: np.ndarray, dt: float, cutoff_hz: Optional[float], order: int = 2
) -> np.ndarray:
"""
Zero-phase (acausal) low-pass filter using Butterworth + filtfilt.

Args:
data: array shaped (T, D) or (T,)
dt: sampling interval
cutoff_hz: cutoff frequency in Hz. If None or invalid, returns data unchanged.
order: filter order (default 2)

Returns:
Filtered data with same shape.
"""
if cutoff_hz is None or cutoff_hz <= 0:
return data
fs = 1.0 / dt
nyq = 0.5 * fs
Wn = cutoff_hz / nyq
# If cutoff is above Nyquist, skip filtering to avoid warnings/instability
if Wn >= 1.0:
return data

b, a = butter(N=order, Wn=Wn, btype="low", analog=False)
# Ensure 2D for consistent axis handling
data_2d = data if data.ndim == 2 else data[:, None]
# filtfilt is along axis=0 (time)
filtered = filtfilt(b, a, data_2d, axis=0, method="pad")
filtered = np.ascontiguousarray(filtered) # ensure positive strides
return filtered if data.ndim == 2 else filtered[:, 0]


def load_data(
self,
dataset_path: Path,
Expand Down Expand Up @@ -346,8 +412,9 @@ def load_data(
t_new = np.linspace(0, T * dt, T_new)

resampled_joint_positions = self._resample_data_Rn(jp_list, t_orig, t_new)
# ── Joint velocities with wrapped-angle difference to handle ±π jumps ──
resampled_joint_velocities = self._compute_raw_derivative(
resampled_joint_positions, simulation_dt
resampled_joint_positions, simulation_dt, angular=True
)

resampled_base_positions = self._resample_data_Rn(
Expand All @@ -358,7 +425,7 @@ def load_data(
)

resampled_base_lin_vel_mixed = self._compute_raw_derivative(
resampled_base_positions, simulation_dt
resampled_base_positions, simulation_dt, angular=False
)

resampled_base_ang_vel_mixed = self._compute_ang_vel(
Expand All @@ -377,6 +444,25 @@ def load_data(
resampled_base_orientations, simulation_dt, local=True
)

# ── Optional 2nd-order acausal low-pass filtering of velocities ──
if self.vel_filter_cutoff_hz is not None:
c = self.vel_filter_cutoff_hz
resampled_joint_velocities = self._lowpass_acausal(
resampled_joint_velocities, simulation_dt, c, order=2
)
resampled_base_lin_vel_mixed = self._lowpass_acausal(
resampled_base_lin_vel_mixed, simulation_dt, c, order=2
)
resampled_base_ang_vel_mixed = self._lowpass_acausal(
resampled_base_ang_vel_mixed, simulation_dt, c, order=2
)
resampled_base_lin_vel_local = self._lowpass_acausal(
resampled_base_lin_vel_local, simulation_dt, c, order=2
)
resampled_base_ang_vel_local = self._lowpass_acausal(
resampled_base_ang_vel_local, simulation_dt, c, order=2
)

return MotionData(
joint_positions=resampled_joint_positions,
joint_velocities=resampled_joint_velocities,
Expand Down