Skip to content

Commit ed54f36

Browse files
committed
factor stuff out
1 parent 657288a commit ed54f36

File tree

1 file changed

+95
-60
lines changed

1 file changed

+95
-60
lines changed

src/rpad/rlbench_utils/placement_dataset.py

+95-60
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,72 @@ class AnchorMode(str, Enum):
191191
SINGLE_OBJECT = "single_object"
192192

193193

194+
def get_anchor_points(
195+
anchor_mode: AnchorMode,
196+
rgb,
197+
point_cloud,
198+
mask,
199+
task_name,
200+
phase,
201+
use_from_simulator=False,
202+
handle_mapping=None,
203+
names_to_handles=None,
204+
):
205+
if anchor_mode == AnchorMode.RAW:
206+
return rgb, point_cloud
207+
elif anchor_mode == AnchorMode.BACKGROUND_REMOVED:
208+
return filter_out_names(
209+
rgb, point_cloud, mask, handle_mapping, BACKGROUND_NAMES
210+
)
211+
elif anchor_mode == AnchorMode.BACKGROUND_ROBOT_REMOVED:
212+
return filter_out_names(
213+
rgb,
214+
point_cloud,
215+
mask,
216+
handle_mapping,
217+
BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES,
218+
)
219+
elif anchor_mode == AnchorMode.SINGLE_OBJECT:
220+
if use_from_simulator:
221+
return get_rgb_point_cloud_by_object_names(
222+
rgb,
223+
point_cloud,
224+
mask,
225+
TASK_DICT[task_name]["phase"][phase]["anchor_obj_names"],
226+
)
227+
else:
228+
return get_rgb_point_cloud_by_object_handles(
229+
rgb,
230+
point_cloud,
231+
mask,
232+
names_to_handles[phase]["anchor_obj_names"],
233+
)
234+
else:
235+
raise ValueError("Anchor mode must be one of the AnchorMode enum values.")
236+
237+
238+
def get_action_points(
239+
action_mode: ActionMode,
240+
rgb,
241+
point_cloud,
242+
mask,
243+
action_handles,
244+
gripper_handles,
245+
):
246+
if action_mode == ActionMode.GRIPPER_AND_OBJECT:
247+
action_handles = action_handles + gripper_handles
248+
elif action_mode == ActionMode.OBJECT:
249+
pass
250+
else:
251+
raise ValueError("Action mode must be one of the ActionMode enum values.")
252+
253+
action_rgb, action_point_cloud = get_rgb_point_cloud_by_object_handles(
254+
rgb, point_cloud, mask, action_handles
255+
)
256+
257+
return action_rgb, action_point_cloud
258+
259+
194260
class RLBenchPlacementDataset(data.Dataset):
195261
def __init__(
196262
self,
@@ -299,7 +365,7 @@ def _load_keyframes(
299365

300366
keyframes = [demo[ix] for ix in keyframe_ixs]
301367

302-
return keyframes, demo[0]
368+
return keyframes, demo[0] # type: ignore
303369

304370
# We also cache in memory, since all the transformations are the same.
305371
# Saves a lot of time when loading the dataset, but don't have to worry
@@ -347,69 +413,38 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
347413
# Find the first grasp instance
348414
key_obs = keyframes[phase_ix]
349415

350-
if self.debugging:
351-
raise ValueError("Debugging not implemented.")
352-
return {
353-
"keyframes": keyframe_ixs,
354-
"demo": demo,
355-
"initial_obs": initial_obs,
356-
"key_obs": key_obs,
357-
"init_front_rgb": torch.from_numpy(initial_obs.front_rgb),
358-
"key_front_rgb": torch.from_numpy(key_obs.front_rgb),
359-
"init_front_mask": torch.from_numpy(
360-
initial_obs.front_mask.astype(np.int32)
361-
),
362-
"key_front_mask": torch.from_numpy(key_obs.front_mask.astype(np.int32)),
363-
"phase": phase,
364-
"phase_onehot": torch.from_numpy(phase_onehot),
365-
}
416+
action_handles = self.names_to_handles[phase]["action_obj_names"]
417+
418+
def _select_action_vals(rgb, point_cloud, mask):
419+
return get_action_points(
420+
self.action_mode,
421+
rgb,
422+
point_cloud,
423+
mask,
424+
action_handles,
425+
self.gripper_handles,
426+
)
427+
428+
def _select_anchor_vals(rgb, point_cloud, mask):
429+
return get_anchor_points(
430+
self.anchor_mode,
431+
rgb,
432+
point_cloud,
433+
mask,
434+
self.task_name,
435+
phase,
436+
use_from_simulator=False,
437+
handle_mapping=self.handle_mapping,
438+
names_to_handles=self.names_to_handles,
439+
)
366440

367441
# Merge all the initial point clouds and masks into one.
368442
init_rgb, init_point_cloud, init_mask = obs_to_rgb_point_cloud(initial_obs)
369443

370-
action_handles = self.names_to_handles[phase]["action_obj_names"]
371-
if self.action_mode == ActionMode.GRIPPER_AND_OBJECT:
372-
action_handles = action_handles + self.gripper_handles
373-
elif self.action_mode == ActionMode.OBJECT:
374-
pass
375-
else:
376-
raise ValueError("Action mode must be one of the ActionMode enum values.")
377-
378-
# Split the initial point cloud and rgb into action and anchor.
379-
(
380-
init_action_rgb,
381-
init_action_point_cloud,
382-
) = get_rgb_point_cloud_by_object_handles(
383-
init_rgb, init_point_cloud, init_mask, action_handles
444+
init_action_rgb, init_action_point_cloud = _select_action_vals(
445+
init_rgb, init_point_cloud, init_mask
384446
)
385447

386-
def _select_anchor_vals(rgb, point_cloud, mask):
387-
if self.anchor_mode == AnchorMode.RAW:
388-
return rgb, point_cloud
389-
elif self.anchor_mode == AnchorMode.BACKGROUND_REMOVED:
390-
return filter_out_names(
391-
rgb, point_cloud, mask, self.handle_mapping, BACKGROUND_NAMES
392-
)
393-
elif self.anchor_mode == AnchorMode.BACKGROUND_ROBOT_REMOVED:
394-
return filter_out_names(
395-
rgb,
396-
point_cloud,
397-
mask,
398-
self.handle_mapping,
399-
BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES,
400-
)
401-
elif self.anchor_mode == AnchorMode.SINGLE_OBJECT:
402-
return get_rgb_point_cloud_by_object_handles(
403-
rgb,
404-
point_cloud,
405-
mask,
406-
self.names_to_handles[phase]["anchor_obj_names"],
407-
)
408-
else:
409-
raise ValueError(
410-
"Anchor mode must be one of the AnchorMode enum values."
411-
)
412-
413448
init_anchor_rgb, init_anchor_point_cloud = _select_anchor_vals(
414449
init_rgb, init_point_cloud, init_mask
415450
)
@@ -418,8 +453,8 @@ def _select_anchor_vals(rgb, point_cloud, mask):
418453
key_rgb, key_point_cloud, key_mask = obs_to_rgb_point_cloud(key_obs)
419454

420455
# Split the key point cloud and rgb into action and anchor.
421-
key_action_rgb, key_action_point_cloud = get_rgb_point_cloud_by_object_handles(
422-
key_rgb, key_point_cloud, key_mask, action_handles
456+
key_action_rgb, key_action_point_cloud = _select_action_vals(
457+
key_rgb, key_point_cloud, key_mask
423458
)
424459
key_anchor_rgb, key_anchor_point_cloud = _select_anchor_vals(
425460
key_rgb, key_point_cloud, key_mask

0 commit comments

Comments
 (0)