From fa3c65976eddda2ad4242bad7ebda31b9507cc40 Mon Sep 17 00:00:00 2001 From: MiaoDX Date: Mon, 19 Jan 2026 20:30:53 +0800 Subject: [PATCH] fix(g1_env): apply gravity compensation torques in queue_action queue_action() previously always sent body_tau=zeros, ignoring the enable_gravity_compensation flag. Now computes and applies gravity feedforward torques when enabled, improving arm tracking accuracy. Also adds unit test to prevent regression. --- gr00t_wbc/control/envs/g1/g1_env.py | 11 ++++- tests/control/envs/__init__.py | 0 tests/control/envs/g1/__init__.py | 0 tests/control/envs/g1/test_g1_env.py | 70 ++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 tests/control/envs/__init__.py create mode 100644 tests/control/envs/g1/__init__.py create mode 100644 tests/control/envs/g1/test_g1_env.py diff --git a/gr00t_wbc/control/envs/g1/g1_env.py b/gr00t_wbc/control/envs/g1/g1_env.py index f687060..7a7b200 100644 --- a/gr00t_wbc/control/envs/g1/g1_env.py +++ b/gr00t_wbc/control/envs/g1/g1_env.py @@ -222,11 +222,20 @@ def queue_action(self, action: Dict[str, any]): # Map action from joint order to actuator order body_actuator_q = self.robot_model.get_body_actuated_joints(action["q"]) + # Compute gravity compensation torques if enabled + body_tau = np.zeros_like(body_actuator_q) + if self.enable_gravity_compensation and self.last_obs is not None: + current_q = self.last_obs["q"] + gravity_torques = self.robot_model.compute_gravity_compensation_torques( + current_q, joint_groups=self.gravity_compensation_joints + ) + body_tau = self.robot_model.get_body_actuated_joints(gravity_torques) + self.body().queue_action( { "body_q": body_actuator_q, "body_dq": np.zeros_like(body_actuator_q), - "body_tau": np.zeros_like(body_actuator_q), + "body_tau": body_tau, } ) diff --git a/tests/control/envs/__init__.py b/tests/control/envs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/control/envs/g1/__init__.py b/tests/control/envs/g1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/control/envs/g1/test_g1_env.py b/tests/control/envs/g1/test_g1_env.py new file mode 100644 index 0000000..6150c13 --- /dev/null +++ b/tests/control/envs/g1/test_g1_env.py @@ -0,0 +1,70 @@ +""" +Integration tests for G1Env. + +Tests the gravity compensation fix from commit c28acee. +""" + +import numpy as np +import pytest +from unittest.mock import MagicMock + +from gr00t_wbc.control.robot_model.instantiation import instantiate_g1_robot_model + + +class TestG1EnvGravityCompensation: + """ + Integration test for G1Env.queue_action() gravity compensation. + + This test WILL FAIL if the fix is reverted (old code always sent body_tau=zeros). + """ + + @pytest.mark.parametrize("enable_gravity,expect_nonzero", [ + (True, True), # Gravity enabled → expect non-zero torques + (False, False), # Gravity disabled → expect zero torques + ]) + def test_queue_action_gravity_compensation(self, enable_gravity, expect_nonzero): + """ + Test that G1Env.queue_action() respects the gravity compensation flag. + """ + from gr00t_wbc.control.envs.g1.g1_env import G1Env + + robot_model = instantiate_g1_robot_model(waist_location="lower_body") + + # Create minimal mock G1Env with real queue_action method + env = object.__new__(G1Env) + env.robot_model = robot_model + env.enable_gravity_compensation = enable_gravity + env.gravity_compensation_joints = ["arms"] + env.with_hands = False # Skip hand commands in queue_action + + # Set up last_obs with arm configuration + arm_indices = robot_model.get_joint_group_indices("arms") + q = np.zeros(robot_model.num_dofs) + for idx in arm_indices: + q[idx] = 0.3 + env.last_obs = {"q": q} + + # Mock safety_monitor + mock_safety = MagicMock() + mock_safety.handle_violations = lambda obs, action: {"action": action} + env.safety_monitor = mock_safety + + # Capture what queue_action sends + captured_actions = [] + mock_body = MagicMock() + mock_body.queue_action = lambda a: captured_actions.append(a) + env.body = lambda: mock_body + + # Call actual queue_action + env.queue_action({"q": np.zeros(robot_model.num_dofs)}) + + # Verify + assert len(captured_actions) == 1 + body_tau = captured_actions[0]["body_tau"] + + if expect_nonzero: + assert not np.allclose(body_tau, 0), \ + "BUG: body_tau is zeros when gravity compensation is enabled!" + else: + assert np.allclose(body_tau, 0), \ + "body_tau should be zeros when gravity compensation is disabled"