20 | 20 |
21 | 21 | import tensordict.tensordict
22 | 22 | import torch
| 23 | +from tensordict.nn import WrapModule |
23 | 24 |
24 | 25 | from torchrl.collectors import MultiSyncDataCollector
25 | 26 |

106 | 107 | CenterCrop,
107 | 108 | ClipTransform,
108 | 109 | Compose,
| 110 | + ConditionalPolicySwitch, |
109 | 111 | Crop,
110 | 112 | DeviceCastTransform,
111 | 113 | DiscreteActionProjection,

@@ -13271,6 +13273,206 @@ def test_composite_reward_spec(self) -> None:
13271 | 13273 | assert transform.transform_reward_spec(reward_spec) == expected_reward_spec
13272 | 13274 |
13273 | 13275 |
| 13276 | +class TestConditionalPolicySwitch(TransformBase): |
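| | + """Tests for ConditionalPolicySwitch: the transform runs the provided ``policy`` whenever ``condition(td)`` is True during a rollout.""" |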
| 13277 | + def test_single_trans_env_check(self): |
| 13278 | + base_env = CountingEnv(max_steps=15) |
| 13279 | + condition = lambda td: ((td.get("step_count") % 2) == 0).all() |
| 13280 | + # Player 0: policy_odd plays zeros, policy_even plays ones. Both lambdas close over |
| | + # `env`, which is assigned a few lines below, before either policy is ever evaluated. |
| 13281 | + policy_odd = lambda td: td.set("action", env.action_spec.zero()) |
| 13282 | + policy_even = lambda td: td.set("action", env.action_spec.one()) |
| 13283 | + transforms = Compose( |
| 13284 | + StepCounter(), |
| 13285 | + ConditionalPolicySwitch(condition=condition, policy=policy_even), |
| 13286 | + ) |
| 13287 | + env = base_env.append_transform(transforms) |
| 13288 | + env.check_env_specs() |
| 13289 | + |
| 13290 | + def _create_policy_odd(self, base_env): |
| 13291 | + return WrapModule( |
| 13292 | + lambda td, base_env=base_env: td.set( |
| 13293 | + "action", base_env.action_spec_unbatched.zero(td.shape) |
| 13294 | + ), |
| 13295 | + out_keys=["action"], |
| 13296 | + ) |
| 13297 | + |
| 13298 | + def _create_policy_even(self, base_env): |
| 13299 | + return WrapModule( |
| 13300 | + lambda td, base_env=base_env: td.set( |
| 13301 | + "action", base_env.action_spec_unbatched.one(td.shape) |
| 13302 | + ), |
| 13303 | + out_keys=["action"], |
| 13304 | + ) |
| 13305 | + |
| 13306 | + def _create_transforms(self, condition, policy_even): |
| 13307 | + return Compose( |
| 13308 | + StepCounter(), |
| 13309 | + ConditionalPolicySwitch(condition=condition, policy=policy_even), |
| 13310 | + ) |
| 13311 | + |
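| | + # Helper used by the serial/parallel_trans_env tests: each sub-env gets its own |
| | + # StepCounter + ConditionalPolicySwitch stack (the trans_serial/parallel tests instead |
| | + # append a single transform stack to the batched env). |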
| 13312 | + def _make_env(self, max_count, env_cls): |
| 13313 | + torch.manual_seed(0) |
| 13314 | + condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1) |
| 13315 | + base_env = env_cls(max_steps=max_count) |
| 13316 | + policy_even = self._create_policy_even(base_env) |
| 13317 | + transforms = self._create_transforms(condition, policy_even) |
| 13318 | + return base_env.append_transform(transforms) |
| 13319 | + |
| 13320 | + def _test_env(self, env, policy_odd): |
| 13321 | + env.check_env_specs() |
| 13322 | + env.set_seed(0) |
| 13323 | + r = env.rollout(100, policy_odd, break_when_any_done=False) |
| 13324 | + # Check results are independent: one reset / step in one env should not impact results in another |
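| | + # The three sub-envs use max_steps 6, 7 and 8, so each recorded trajectory repeats |
| | + # with that period and every later chunk must match the first one. |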
| 13325 | + r0, r1, r2 = r.unbind(0) |
| 13326 | + r0_split = r0.split(6) |
| 13327 | + assert all((chunk == r0_split[0][: chunk.numel()]).all() for chunk in r0_split[1:]) |
| 13328 | + r1_split = r1.split(7) |
| 13329 | + assert all((chunk == r1_split[0][: chunk.numel()]).all() for chunk in r1_split[1:]) |
| 13330 | + r2_split = r2.split(8) |
| 13331 | + assert all((chunk == r2_split[0][: chunk.numel()]).all() for chunk in r2_split[1:]) |
| 13332 | + |
| 13333 | + def test_trans_serial_env_check(self): |
| 13334 | + torch.manual_seed(0) |
| 13335 | + base_env = SerialEnv( |
| 13336 | + 3, |
| 13337 | + [partial(CountingEnv, max_steps=6), partial(CountingEnv, max_steps=7), partial(CountingEnv, max_steps=8)], |
| 13338 | + batch_locked=False, |
| 13339 | + ) |
| 13340 | + condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1) |
| 13341 | + policy_odd = self._create_policy_odd(base_env) |
| 13342 | + policy_even = self._create_policy_even(base_env) |
| 13343 | + transforms = self._create_transforms(condition, policy_even) |
| 13344 | + env = base_env.append_transform(transforms) |
| 13345 | + self._test_env(env, policy_odd) |
| 13346 | + |
| 13347 | + def test_trans_parallel_env_check(self): |
| 13348 | + torch.manual_seed(0) |
| 13349 | + base_env = ParallelEnv( |
| 13350 | + 3, |
| 13351 | + [partial(CountingEnv, max_steps=6), partial(CountingEnv, max_steps=7), partial(CountingEnv, max_steps=8)], |
| 13352 | + batch_locked=False, |
| 13353 | + mp_start_method=mp_ctx, |
| 13354 | + ) |
| 13355 | + condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1) |
| 13356 | + policy_odd = self._create_policy_odd(base_env) |
| 13357 | + policy_even = self._create_policy_even(base_env) |
| 13358 | + transforms = self._create_transforms(condition, policy_even) |
| 13359 | + env = base_env.append_transform(transforms) |
| 13360 | + self._test_env(env, policy_odd) |
| 13361 | + |
| 13362 | + def test_serial_trans_env_check(self): |
| 13363 | + condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1) |
| 13364 | + policy_odd = self._create_policy_odd(CountingEnv()) |
| 13365 | + |
| 13366 | + def make_env(max_count): |
| 13367 | + return partial(self._make_env, max_count, CountingEnv) |
| 13368 | + |
| 13369 | + env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)]) |
| 13370 | + self._test_env(env, policy_odd) |
| 13371 | + |
| 13372 | + def test_parallel_trans_env_check(self): |
| 13373 | + condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1) |
| 13374 | + policy_odd = self._create_policy_odd(CountingEnv()) |
| 13375 | + |
| 13376 | + def make_env(max_count): |
| 13377 | + return partial(self._make_env, max_count, CountingEnv) |
| 13378 | + |
| 13379 | + env = ParallelEnv( |
| 13380 | + 3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx |
| 13381 | + ) |
| 13382 | + self._test_env(env, policy_odd) |
| 13383 | + |
| 13384 | + def test_transform_no_env(self): |
| 13385 | + policy_odd = lambda td: td |
| 13386 | + policy_even = lambda td: td |
| 13387 | + condition = lambda td: True |
| 13388 | + transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even) |
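| | + # The transform is not usable as a standalone module: calling it directly must raise. |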
| 13389 | + with pytest.raises( |
| 13390 | + RuntimeError, |
| 13391 | + match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.", |
| 13392 | + ): |
| 13393 | + transforms(TensorDict()) |
| 13394 | + |
| 13395 | + def test_transform_compose(self): |
| 13396 | + policy_odd = lambda td: td |
| 13397 | + policy_even = lambda td: td |
| 13398 | + condition = lambda td: True |
| 13399 | + transforms = Compose( |
| 13400 | + ConditionalPolicySwitch(condition=condition, policy=policy_even), |
| 13401 | + ) |
| 13402 | + with pytest.raises( |
| 13403 | + RuntimeError, |
| 13404 | + match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.", |
| 13405 | + ): |
| 13406 | + transforms(TensorDict()) |
| 13407 | + |
| 13408 | + def test_transform_env(self): |
| 13409 | + base_env = CountingEnv(max_steps=15) |
| 13410 | + condition = lambda td: ((td.get("step_count") % 2) == 0).all() |
| 13411 | + # Player 0: policy_odd plays zeros on the odd step counts, policy_even plays ones on the even ones. |
| 13412 | + policy_odd = lambda td: td.set("action", env.action_spec.zero()) |
| 13413 | + policy_even = lambda td: td.set("action", env.action_spec.one()) |
| 13414 | + transforms = Compose( |
| 13415 | + StepCounter(), |
| 13416 | + ConditionalPolicySwitch(condition=condition, policy=policy_even), |
| 13417 | + ) |
| 13418 | + env = base_env.append_transform(transforms) |
| 13419 | + env.check_env_specs() |
| 13420 | + r = env.rollout(1000, policy_odd, break_when_all_done=True) |
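| | + # Only the frames where policy_odd acted are recorded: the switch hands the even |
| | + # step counts to policy_even, so every recorded action is 0 and step_count is odd. |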
| 13421 | + assert r.shape[0] == 15 |
| 13422 | + assert (r["action"] == 0).all() |
| 13423 | + assert ( |
| 13424 | + r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1) |
| 13425 | + ).all() |
| 13426 | + assert r["next", "done"].any() |
| 13427 | + |
| 13428 | + # Player 1 |
| 13429 | + condition = lambda td: ((td.get("step_count") % 2) == 1).all() |
| 13430 | + transforms = Compose( |
| 13431 | + StepCounter(), |
| 13432 | + ConditionalPolicySwitch(condition=condition, policy=policy_odd), |
| 13433 | + ) |
| 13434 | + env = base_env.append_transform(transforms) |
| 13435 | + r = env.rollout(1000, policy_even, break_when_all_done=True) |
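| | + # Mirror case: policy_even now drives the rollout and the switch covers the odd step |
| | + # counts, so every recorded action is 1 and the step counts are the even numbers from 0. |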
| 13436 | + assert r.shape[0] == 16 |
| 13437 | + assert (r["action"] == 1).all() |
| 13438 | + assert ( |
| 13439 | + r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1) |
| 13440 | + ).all() |
| 13441 | + assert r["next", "done"].any() |
| 13442 | + |
| 13443 | + def test_transform_model(self): |
| 13444 | + policy_odd = lambda td: td |
| 13445 | + policy_even = lambda td: td |
| 13446 | + condition = lambda td: True |
| 13447 | + transforms = nn.Sequential( |
| 13448 | + ConditionalPolicySwitch(condition=condition, policy=policy_even), |
| 13449 | + ) |
| 13450 | + with pytest.raises( |
| 13451 | + RuntimeError, |
| 13452 | + match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.", |
| 13453 | + ): |
| 13454 | + transforms(TensorDict()) |
| 13455 | + |
| 13456 | + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) |
| 13457 | + def test_transform_rb(self, rbclass): |
| 13458 | + policy_odd = lambda td: td |
| 13459 | + policy_even = lambda td: td |
| 13460 | + condition = lambda td: True |
| 13461 | + rb = rbclass(storage=LazyTensorStorage(10)) |
| 13462 | + rb.append_transform( |
| 13463 | + ConditionalPolicySwitch(condition=condition, policy=policy_even) |
| 13464 | + ) |
| 13465 | + rb.extend(TensorDict(batch_size=[2])) |
| 13466 | + with pytest.raises( |
| 13467 | + RuntimeError, |
| 13468 | + match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.", |
| 13469 | + ): |
| 13470 | + rb.sample(2) |
| 13471 | + |
| 13472 | + def test_transform_inverse(self): |
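| | + # ConditionalPolicySwitch defines no inverse transform, so there is nothing to check here. |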
| 13473 | + return |
| 13474 | + |
| 13475 | + |
13274 | 13476 | if __name__ == "__main__":
13275 | 13477 | args, unknown = argparse.ArgumentParser().parse_known_args()
13276 | 13478 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)