
Commit da1460c

add SPiRL closed-loop model for block stacking, update READMEs
1 parent cf3b90f commit da1460c

8 files changed, +305 -4 lines changed

README.md (+2)

@@ -16,6 +16,8 @@ This is the official PyTorch implementation of the paper "**Accelerating Reinfor
 (CoRL 2020).
 
 ## Updates
+- **[Apr 2021]**: extended improved SPiRL version to support image-based observations
+(see [example commands](spirl/configs/skill_prior_learning/block_stacking/hierarchical_cl/README.md))
 - **[Mar 2021]**: added an improved version of SPiRL with closed-loop skill decoder
 (see [example commands](spirl/configs/skill_prior_learning/kitchen/hierarchical_cl/README.md))

spirl/configs/hrl/block_stacking/spirl_cl/conf.py (new file, +157)

@@ -0,0 +1,157 @@
import os
import copy

from spirl.utils.general_utils import AttrDict
from spirl.rl.components.agent import FixedIntervalHierarchicalAgent
from spirl.rl.components.critic import SplitObsMLPCritic
from spirl.rl.components.sampler import ACMultiImageAugmentedHierarchicalSampler
from spirl.rl.components.replay_buffer import UniformReplayBuffer
from spirl.rl.policies.prior_policies import ACLearnedPriorAugmentedPIPolicy
from spirl.rl.envs.block_stacking import HighStack11StackEnvV0, SparseHighStack11StackEnvV0
from spirl.rl.agents.ac_agent import SACAgent
from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
from spirl.rl.policies.cl_model_policies import ACClModelPolicy
from spirl.models.closed_loop_spirl_mdl import ImageClSPiRLMdl
from spirl.configs.default_data_configs.block_stacking import data_spec


current_dir = os.path.dirname(os.path.realpath(__file__))

notes = 'used to test the RL implementation'

configuration = {
    'seed': 42,
    'agent': FixedIntervalHierarchicalAgent,
    'environment': SparseHighStack11StackEnvV0,
    'sampler': ACMultiImageAugmentedHierarchicalSampler,
    'data_dir': '.',
    'num_epochs': 100,
    'max_rollout_len': 1000,
    'n_steps_per_epoch': 1e5,
    'n_warmup_steps': 5e3,
}
configuration = AttrDict(configuration)


# Replay Buffer
replay_params = AttrDict(
    capacity=1e5,
    dump_replay=False,
)

# Observation Normalization
obs_norm_params = AttrDict(
)

sampler_config = AttrDict(
    n_frames=2,
)

base_agent_params = AttrDict(
    batch_size=256,
    replay=UniformReplayBuffer,
    replay_params=replay_params,
    clip_q_target=False,
)


###### Low-Level ######
# LL Policy Model
ll_model_params = AttrDict(
    state_dim=data_spec.state_dim,
    action_dim=data_spec.n_actions,
    n_rollout_steps=10,
    kl_div_weight=1e-2,
    prior_input_res=data_spec.res,
    n_input_frames=2,
    cond_decode=True,
)

# LL Policy
ll_policy_params = AttrDict(
    policy_model=ImageClSPiRLMdl,
    policy_model_params=ll_model_params,
    policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_learning/block_stacking/hierarchical_cl"),
    initial_log_sigma=-50.,
)
ll_policy_params.update(ll_model_params)

# LL Critic
ll_critic_params = AttrDict(
    action_dim=data_spec.n_actions,
    input_dim=data_spec.state_dim,
    output_dim=1,
    action_input=True,
    unused_obs_size=10,     # ignore HL policy z output in observation for LL critic
)

# LL Agent
ll_agent_config = copy.deepcopy(base_agent_params)
ll_agent_config.update(AttrDict(
    policy=ACClModelPolicy,
    policy_params=ll_policy_params,
    critic=SplitObsMLPCritic,
    critic_params=ll_critic_params,
))


###### High-Level ########
# HL Policy
hl_policy_params = AttrDict(
    action_dim=10,          # z-dimension of the skill VAE
    input_dim=data_spec.state_dim,
    max_action_range=2.,    # prior is Gaussian with unit variance
    prior_model=ll_policy_params.policy_model,
    prior_model_params=ll_policy_params.policy_model_params,
    prior_model_checkpoint=ll_policy_params.policy_model_checkpoint,
)

# HL Critic
hl_critic_params = AttrDict(
    action_dim=hl_policy_params.action_dim,
    input_dim=hl_policy_params.input_dim,
    output_dim=1,
    n_layers=2,             # number of policy network layers
    nz_mid=256,
    action_input=True,
    unused_obs_size=ll_model_params.prior_input_res ** 2 * 3 * ll_model_params.n_input_frames,
)

# HL Agent
hl_agent_config = copy.deepcopy(base_agent_params)
hl_agent_config.update(AttrDict(
    policy=ACLearnedPriorAugmentedPIPolicy,
    policy_params=hl_policy_params,
    critic=SplitObsMLPCritic,
    critic_params=hl_critic_params,
    td_schedule_params=AttrDict(p=5.),
))


##### Joint Agent #######
agent_config = AttrDict(
    hl_agent=ActionPriorSACAgent,
    hl_agent_params=hl_agent_config,
    ll_agent=SACAgent,
    ll_agent_params=ll_agent_config,
    hl_interval=ll_model_params.n_rollout_steps,
    log_videos=True,
    update_hl=True,
    update_ll=False,
)

# Dataset - Random data
data_config = AttrDict()
data_config.dataset_spec = data_spec

# Environment
env_config = AttrDict(
    name="block_stacking",
    reward_norm=1.,
    screen_width=data_spec.res,
    screen_height=data_spec.res,
    env_config=AttrDict(camera_name='agentview',
                        screen_width=data_spec.res,
                        screen_height=data_spec.res,)
)
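
The coupling to note in this config is `hl_interval=ll_model_params.n_rollout_steps`: the high-level policy emits a new 10-dimensional skill latent every 10 environment steps, while the closed-loop low-level decoder maps the current observation plus the active latent to an action at every step. The sketch below is only a schematic of that fixed-interval rollout, with stand-in callables rather than SPiRL's agent and sampler classes:

```
# Schematic only: stand-in callables instead of SPiRL's agent/sampler classes.
import numpy as np

class RandomStub:
    """Returns a random vector of fixed size; stands in for env step / policies."""
    def __init__(self, out_dim):
        self.out_dim = out_dim

    def __call__(self, *inputs):
        return np.random.randn(self.out_dim)

def fixed_interval_rollout(env_step, hl_policy, ll_decoder, obs, n_steps=30, hl_interval=10):
    """Every `hl_interval` steps the HL policy picks a skill latent z;
    the closed-loop LL decoder conditions on the *current* obs and z at every step."""
    z = None
    for t in range(n_steps):
        if t % hl_interval == 0:        # HL decision point (hl_interval = n_rollout_steps = 10)
            z = hl_policy(obs)          # 10-dim skill latent, cf. hl_policy_params.action_dim
        action = ll_decoder(obs, z)     # closed-loop decoding step
        obs = env_step(action)          # advance the (stubbed) environment
    return obs

obs = np.random.randn(32)
fixed_interval_rollout(env_step=RandomStub(32), hl_policy=RandomStub(10), ll_decoder=RandomStub(7), obs=obs)
```

With `update_hl=True` and `update_ll=False` in `agent_config`, only the high-level agent is trained during RL; the low-level decoder keeps the weights loaded from `policy_model_checkpoint`.
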
spirl/configs/skill_prior_learning/block_stacking/hierarchical_cl/README.md (new file, +34)

@@ -0,0 +1,34 @@
# Image-based SPiRL w/ Closed-Loop Skill Decoder

This version of the SPiRL model uses a [closed-loop action decoder](../../../../models/closed_loop_spirl_mdl.py#L55):
in contrast to the original SPiRL model, it takes the current environment observation as input in every skill decoding step.

This image-based model is a direct extension of the
[state-based SPiRL model with closed-loop skill decoder](../../kitchen/hierarchical_cl/README.md).
Similar to the state-based model, we find that the image-based closed-loop model improves performance over the original
image-based SPiRL model, particularly in tasks that require precise control.
We evaluate it on a more challenging, sparse-reward version of the block stacking environment
in which the agent is rewarded for the height of the tower it builds, but does not receive any reward for picking or lifting
blocks. On this challenging environment, the closed-loop skill decoder ("SPiRLv2") outperforms the original
SPiRL model with open-loop skill decoder ("SPiRLv1").

<p align="center">
<img src="../../../../../docs/resources/block_stacking_sparse_results.png" width="400">
</p>

We also tried the closed-loop model on the image-based maze navigation task, but did not find it to improve performance,
which we attribute to the easier control task that does not require closed-loop control.

## Example Commands

To train the image-based SPiRL model with closed-loop action decoder on the block stacking environment, run the following command:
```
python3 spirl/train.py --path=spirl/configs/skill_prior_learning/block_stacking/hierarchical_cl --val_data_size=160
```

To train a downstream task policy with RL using the closed-loop image-based SPiRL model
on the sparse-reward block stacking environment, run the following command:
```
python3 spirl/rl/train.py --path=spirl/configs/hrl/block_stacking/spirl_cl --seed=0 --prefix=SPIRLv2_block_stacking_seed0
```
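
The open-loop/closed-loop distinction this README describes boils down to where the decoder gets its observations. A minimal sketch of the two decoding schemes, with a hypothetical `decoder_step` callable standing in for the networks in `closed_loop_spirl_mdl.py`:

```
import numpy as np

def decode_open_loop(decoder_step, z, obs0, horizon=10):
    """Original SPiRL ("SPiRLv1"): the action sequence depends only on z and the initial obs."""
    carry, actions = obs0, []
    for _ in range(horizon):
        action, carry = decoder_step(carry, z)      # no new observations are read
        actions.append(action)
    return actions

def decode_closed_loop(decoder_step, z, get_obs, horizon=10):
    """Closed-loop decoder ("SPiRLv2"): every action conditions on the current observation."""
    actions = []
    for _ in range(horizon):
        action, _ = decoder_step(get_obs(), z)      # fresh observation at each step
        actions.append(action)
    return actions

# toy stand-in for the decoder network so the sketch executes
decoder_step = lambda obs, z: (float(np.tanh(obs.sum() + z.sum())), obs)
print(decode_open_loop(decoder_step, z=np.zeros(10), obs0=np.zeros(32))[:2])
print(decode_closed_loop(decoder_step, z=np.zeros(10), get_obs=lambda: np.random.randn(32))[:2])
```
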
spirl/configs/skill_prior_learning/block_stacking/hierarchical_cl/conf.py (new file, +37)

@@ -0,0 +1,37 @@
import os

from spirl.models.skill_prior_mdl import SkillSpaceLogger
from spirl.models.closed_loop_spirl_mdl import ImageClSPiRLMdl
from spirl.utils.general_utils import AttrDict
from spirl.configs.default_data_configs.block_stacking import data_spec
from spirl.components.evaluator import TopOfNSequenceEvaluator


current_dir = os.path.dirname(os.path.realpath(__file__))


configuration = {
    'model': ImageClSPiRLMdl,
    'logger': SkillSpaceLogger,
    'data_dir': os.path.join(os.environ['DATA_DIR'], 'block_stacking'),
    'epoch_cycles_train': 10,
    'evaluator': TopOfNSequenceEvaluator,
    'top_of_n_eval': 100,
    'top_comp_metric': 'mse',
}
configuration = AttrDict(configuration)

model_config = AttrDict(
    state_dim=data_spec.state_dim,
    action_dim=data_spec.n_actions,
    n_rollout_steps=10,
    kl_div_weight=1e-3,
    prior_input_res=data_spec.res,
    n_input_frames=2,
    cond_decode=True,
)

# Dataset
data_config = AttrDict()
data_config.dataset_spec = data_spec
data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + model_config.n_input_frames
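
The final line sets `subseq_len = 10 + 2 = 12`. The sketch below assumes that each training subsequence packs `n_input_frames` frames of image context for the conditioned prior/decoder plus an `n_rollout_steps`-long action segment to embed as one skill; the dataset class and exact index alignment are simplified stand-ins, not SPiRL's loader:

```
import numpy as np

N_ROLLOUT_STEPS, N_INPUT_FRAMES = 10, 2
SUBSEQ_LEN = N_ROLLOUT_STEPS + N_INPUT_FRAMES        # 12, as set in the config above

def sample_subseq(frames, actions, rng=np.random):
    """Slice one training example out of a trajectory of image frames and actions."""
    start = rng.randint(0, len(frames) - SUBSEQ_LEN + 1)
    context = frames[start:start + N_INPUT_FRAMES]                         # conditioning frames
    skill_actions = actions[start + N_INPUT_FRAMES - 1:
                            start + N_INPUT_FRAMES - 1 + N_ROLLOUT_STEPS]  # 10-step skill segment
    return context, skill_actions

frames = np.zeros((100, 32, 32, 3))   # dummy trajectory; real res/action_dim come from data_spec
actions = np.zeros((100, 7))
ctx, seg = sample_subseq(frames, actions)
print(ctx.shape, seg.shape)           # (2, 32, 32, 3) (10, 7)
```
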

spirl/configs/skill_prior_learning/kitchen/hierarchical_cl/README.md (+3 -3)

@@ -11,15 +11,15 @@ SPiRL model, particularly on tasks that require precise control, like in the kit
 </p>
 </img>
 
+For an implementation of the closed-loop SPiRL model that supports image observations,
+see [here](../../block_stacking/hierarchical_cl/README.md).
+
 ## Example Commands
 
 To train the SPiRL model with closed-loop action decoder on the kitchen environment, run the following command:
 ```
 python3 spirl/train.py --path=spirl/configs/skill_prior_learning/kitchen/hierarchical_cl --val_data_size=160
 ```
-Our current implementation of the closed-loop SPiRL model only supports state-based inputs, but an extension to
-image observations is straightforward analogous to how we adapted the
-original SPiRL model for [image inputs](../../../../models/skill_prior_mdl.py#L321).
 
 To train a downstream task policy with RL using the closed-loop SPiRL model we just trained, run the following command:
 ```

spirl/data/block_stacking/src/block_stacking_env.py (+62)

@@ -963,3 +963,65 @@ def _get_reward(self):
     def _has_support(self, block, others):
         return not block.lifted or any([block.stacked_on_loose(b) and self._has_support(b, [bb for bb in others if b.name != bb.name])
                                         for b in others])
+
+
+class SparseHighStackBlockStackEnv(NoOrderBlockStackEnv):
+    """Simple reward function that just rewards the highest stacked tower."""
+    REWARD_SCALE = 1.0
+
+    def _reset_internal(self, keep_sim_object=False):
+        super()._reset_internal(keep_sim_object)
+        self._final_height = 0.
+
+    def get_episode_info(self):
+        ep_info = super().get_episode_info()
+        ep_info.final_height = self._final_height
+        return ep_info
+
+    def _get_reward(self):
+        """Compute reward for stacking blocks without order."""
+        rew_dict = AttrDict()
+
+        max_height = 0.
+        heights, supported_heights = np.zeros(len(self._blocks)), np.zeros(len(self._blocks))
+        for i, block in enumerate(self._blocks):
+            height = block.dist_lifted
+            heights[i] = height
+
+            # set flags
+            if not self._grasped_flag[i]:
+                self._grasped_flag[i] = block.grasped(self.gripper_pos, self.gripper_finger_dist,
+                                                      self.gripper_finger_poses)
+            if not self._lifted_flag[i]:
+                self._lifted_flag[i] = (not self._hp.restrict_grasped or self._grasped_flag[i]) and \
+                                       (not self._hp.restrict_upright or block.upright) and block.lifted
+            if not self._delivered_flag[i]:
+                self._delivered_flag[i] = (not self._hp.restrict_grasped or self._grasped_flag[i]) \
+                                          and (not self._hp.restrict_upright or block.upright) \
+                                          and any([block.above(b) for b in self._blocks if b.name != block.name])
+
+            # compute reward
+            if (not self._hp.restrict_grasped or self._grasped_flag[i]) and \
+                    (not self._hp.restrict_upright or block.upright) and \
+                    self._has_support(block, [b for b in self._blocks if block.name != b.name]):
+                self._stacked_flag[i] = True
+                supported_heights[i] = height
+                if height > max_height:
+                    max_height = height
+        self._final_height = max_height / (2 * self._hp.block_size)
+
+        total_rew = max_height * self.REWARD_SCALE
+
+        rew_dict["heights"] = heights.round(3)
+        rew_dict["sup_heights"] = supported_heights.round(3)
+        rew_dict["rew_total"] = np.array(total_rew).round(3)
+        rew_dict["max_height"] = np.array(self._final_height).round(3)
+
+        self._prev_block_pos = [copy.deepcopy(b.pos) for b in self._blocks]  # update for next round of reward comp
+        self._prev_gripper_pos = copy.deepcopy(self.gripper_pos)
+
+        return rew_dict
+
+    def _has_support(self, block, others):
+        return not block.lifted or any([block.stacked_on_loose(b) and self._has_support(b, [bb for bb in others if b.name != bb.name])
+                                        for b in others])
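
The sparse reward above boils down to: among blocks that are lifted and transitively supported by another block, take the maximum lift height and scale it; grasping or holding a block in the air earns nothing by itself. A stripped-down numeric illustration of that rule on toy data (no MuJoCo; `BLOCK_SIZE` is a made-up stand-in for `self._hp.block_size`):

```
import numpy as np

REWARD_SCALE = 1.0
BLOCK_SIZE = 0.02    # made-up half-extent; the env reads self._hp.block_size instead

def sparse_stack_reward(lift_heights, supported):
    """lift_heights: how far each block has been lifted off the table;
    supported: whether that lifted block (transitively) rests on another block."""
    supported_heights = np.where(supported, lift_heights, 0.0)
    max_height = supported_heights.max()
    reward = max_height * REWARD_SCALE             # only supported height is rewarded
    final_height = max_height / (2 * BLOCK_SIZE)   # tower height in block units
    return reward, final_height

# three blocks: one held in mid-air (unsupported), one stacked on another block, one on the table
rew, height = sparse_stack_reward(np.array([0.10, 0.04, 0.0]),
                                  np.array([False, True, False]))
print(rew, height)    # 0.04 1.0 -> the grasped-but-unstacked block contributes nothing
```
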

spirl/rl/envs/block_stacking.py (+10 -1)

@@ -3,7 +3,7 @@
 from spirl.rl.components.environment import GymEnv
 from spirl.utils.general_utils import AttrDict, ParamDict
 from spirl.data.block_stacking.src.block_stacking_env import BlockStackEnv as UnwrappedBlockStackEnv
-from spirl.data.block_stacking.src.block_stacking_env import NoOrderBlockStackEnv, HighStackBlockStackEnv
+from spirl.data.block_stacking.src.block_stacking_env import HighStackBlockStackEnv, SparseHighStackBlockStackEnv
 from spirl.data.block_stacking.src.block_task_generator import FixedSizeSingleTowerBlockTaskGenerator
 
 
@@ -92,3 +92,12 @@ def _get_default_env_config(self):
         default_env_config.table_size = (1.2, 2.4, 0.8)
         default_env_config.n_blocks = 11
         return default_env_config
+
+
+class SparseHighStack11StackEnvV0(HighStack11StackEnvV0):
+    def _make_env(self, name):
+        default_env_config = self._get_default_env_config()
+        if self._hp.env_config is not None:
+            default_env_config.update(self._hp.env_config)
+
+        return SparseHighStackBlockStackEnv(default_env_config)
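
`SparseHighStack11StackEnvV0._make_env` follows the same pattern as the other wrappers in this file: start from the hard-coded defaults, overlay whatever the RL config passes in `env_config` (camera name and render resolution in the conf above), and hand the merged config to the unwrapped environment class. A minimal sketch of that override pattern with plain dicts and illustrative default values in place of `AttrDict`/`ParamDict`:

```
def make_env(default_config, overrides=None, env_cls=dict):
    """Merge user overrides onto the defaults, then build the environment.
    `env_cls=dict` is just a stand-in for SparseHighStackBlockStackEnv."""
    config = dict(default_config)      # copy so the defaults stay untouched
    if overrides is not None:
        config.update(overrides)       # values from the RL conf win, as in _make_env above
    return env_cls(config)

defaults = {"n_blocks": 11, "table_size": (1.2, 2.4, 0.8), "camera_name": "frontview"}
overrides = {"camera_name": "agentview", "screen_width": 32, "screen_height": 32}
env = make_env(defaults, overrides)
print(env["camera_name"], env["n_blocks"])   # agentview 11
```
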
