@@ -8,154 +8,169 @@
 
 b3d.rr_init("acquire_object_model")
 
+# python scripts/acquire_object_model.py assets/shared_data_bucket/input_data/lysol_static.r3d
 
-parser = argparse.ArgumentParser("acquire_object_mode")
-parser.add_argument("input", help="r3d file", type=str)
-args = parser.parse_args()
-
-filename = args.input
-data = b3d.io.load_r3d(filename)
-
-
-_, _, fx, fy, cx, cy, near, far = data["camera_intrinsics_depth"]
-image_height, image_width = data["depth"].shape[1:3]
-num_scenes = data["depth"].shape[0]
-
-indices = jnp.arange(0, num_scenes, 10)
-
-camera_poses_full = data["camera_pose"]
-camera_poses = camera_poses_full[indices]
-
-xyz = b3d.xyz_from_depth_vectorized(data["depth"][indices], fx, fy, cx, cy)
-xyz_world_frame = camera_poses[:, None, None].apply(xyz)
-
-# for i in range(len(xyz_world_frame)):
-# b3d.rr_set_time(i)
-# b3d.utils.rr_log_cloud("xyz", xyz_world_frame[i])
-
-# Resize rgbs to be same size as depth.
-rgbs = data["rgb"]
-rgbs_resized = jnp.clip(
-    jax.vmap(jax.image.resize, in_axes=(0, None, None))(
-        rgbs[indices] / 255.0,
-        (image_height, image_width, 3),
-        "linear",
-    ),
-    0.0,
-    1.0,
-)
-
-
-masks = [b3d.carvekit_get_foreground_mask(r) for r in rgbs_resized]
-masks_concat = jnp.stack(masks, axis=0)
-
-grid_center = jnp.median(camera_poses[0].apply(xyz[0][masks[0]]), axis=0)
-W = 0.3
-D = 100
-grid = jnp.stack(
-    jnp.meshgrid(
-        jnp.linspace(grid_center[0] - W / 2, grid_center[0] + W / 2, D),
-        jnp.linspace(grid_center[1] - W / 2, grid_center[1] + W / 2, D),
-        jnp.linspace(grid_center[2] - W / 2, grid_center[2] + W / 2, D),
-    ),
-    axis=-1,
-).reshape(-1, 3)
-
-occ_free_occl_, colors_per_voxel_ = (
-    b3d.voxel_occupied_occluded_free_parallel_camera_depth(
-        camera_poses,
-        rgbs_resized,
-        xyz[..., 2] * masks_concat + (1.0 - masks_concat) * 5.0,
-        grid,
-        fx,
-        fy,
-        cx,
-        cy,
-        6.0,
-        0.005,
+
+# ssh sam-b3d-l4.us-west1-a.probcomp-caliban -L 5000:localhost:5000
+
+
+def acquire(input_path, output_path=None):
+    if output_path is None:
+        output_path = input_path + ".graphics_edits.mp4"
+
+    data = b3d.io.load_r3d(input_path)
+
+    _, _, fx, fy, cx, cy, near, far = data["camera_intrinsics_depth"]
+    image_height, image_width = data["depth"].shape[1:3]
+    num_scenes = data["depth"].shape[0]
+
+    indices = jnp.arange(0, num_scenes, 10)
+
+    camera_poses_full = data["camera_pose"]
+    camera_poses = camera_poses_full[indices]
+
+    xyz = b3d.xyz_from_depth_vectorized(data["depth"][indices], fx, fy, cx, cy)
+    xyz_world_frame = camera_poses[:, None, None].apply(xyz)
+
+    # for i in range(len(xyz_world_frame)):
+    # b3d.rr_set_time(i)
+    # b3d.utils.rr_log_cloud("xyz", xyz_world_frame[i])
+
+    # Resize rgbs to be same size as depth.
+    rgbs = data["rgb"]
+    rgbs_resized = jnp.clip(
+        jax.vmap(jax.image.resize, in_axes=(0, None, None))(
+            rgbs[indices] / 255.0,
+            (image_height, image_width, 3),
+            "linear",
+        ),
+        0.0,
+        1.0,
+    )
+
+    masks = [b3d.carvekit_get_foreground_mask(r) for r in rgbs_resized]
+    masks_concat = jnp.stack(masks, axis=0)
+
+    grid_center = jnp.median(camera_poses[0].apply(xyz[0][masks[0]]), axis=0)
+    W = 0.3
+    D = 100
+    grid = jnp.stack(
+        jnp.meshgrid(
+            jnp.linspace(grid_center[0] - W / 2, grid_center[0] + W / 2, D),
+            jnp.linspace(grid_center[1] - W / 2, grid_center[1] + W / 2, D),
+            jnp.linspace(grid_center[2] - W / 2, grid_center[2] + W / 2, D),
+        ),
+        axis=-1,
+    ).reshape(-1, 3)
+
+    occ_free_occl_, colors_per_voxel_ = (
+        b3d.voxel_occupied_occluded_free_parallel_camera_depth(
+            camera_poses,
+            rgbs_resized,
+            xyz[..., 2] * masks_concat + (1.0 - masks_concat) * 5.0,
+            grid,
+            fx,
+            fy,
+            cx,
+            cy,
+            6.0,
+            0.005,
+        )
+    )
+    i = len(occ_free_occl_)
+    occ_free_occl, colors_per_voxel = occ_free_occl_[:i], colors_per_voxel_[:i]
+    total_occ = (occ_free_occl == 1.0).sum(0)
+    total_free = (occ_free_occl == -1.0).sum(0)
+    ratio = total_occ / (total_occ + total_free) * ((total_occ + total_free) > 1)
+
+    grid_colors = colors_per_voxel.sum(0) / (total_occ[..., None])
+    model_mask = ratio > 0.2
+
+    resolution = 0.0015
+
+    grid_points = grid[model_mask]
+    colors = grid_colors[model_mask]
+
+    meshes = b3d.mesh.transform_mesh(
+        jax.vmap(b3d.mesh.Mesh.cube_mesh)(
+            jnp.ones((grid_points.shape[0], 3)) * resolution * 2.0, colors
+        ),
+        b3d.Pose.from_translation(grid_points)[:, None],
+    )
+    _object_mesh = b3d.mesh.Mesh.squeeze_mesh(meshes)
+
+    object_pose = Pose.from_translation(jnp.median(_object_mesh.vertices, axis=0))
+    object_mesh = _object_mesh.transform(object_pose.inv())
+    object_mesh.rr_visualize("mesh")
+
+    mesh_filename = input_path + ".mesh.obj"
+    # Save the mesh
+    print(f"Saving obj file to {mesh_filename}")
+    object_mesh.save(mesh_filename)
+
+    renderer = b3d.RendererOriginal(
+        image_width, image_height, fx, fy, cx, cy, near, far
+    )
+    rgbds = renderer.render_rgbd_many(
+        (camera_poses[:, None].inv() @ object_pose).apply(object_mesh.vertices),
+        object_mesh.faces,
+        jnp.tile(object_mesh.vertex_attributes, (len(camera_poses), 1, 1)),
+    )
+
+    sub_indices = jnp.array([0, 5, len(camera_poses) - 15, len(camera_poses) - 5])
+    mask = rgbds[sub_indices, ..., 3] == 0.0
+
+    background_xyzs = xyz_world_frame[sub_indices][mask]
+    colors = rgbs_resized[sub_indices][mask, :]
+    distances_from_camera = xyz[sub_indices][..., 2][mask][..., None] / fx
+
+    # subset = jax.random.choice(jax.random.PRNGKey(0), jnp.arange(background_xyzs.shape[0]), shape=(background_xyzs.shape[0]//3,), replace=False)
+
+    # background_xyzs = background_xyzs[subset]
+    # colors = colors[subset]
+    # distances_from_camera = distances_from_camera[subset]
+
+    meshes = b3d.mesh.transform_mesh(
+        jax.vmap(b3d.mesh.Mesh.cube_mesh)(
+            jnp.ones((background_xyzs.shape[0], 3)) * distances_from_camera, colors
+        ),
+        b3d.Pose.from_translation(background_xyzs)[:, None],
     )
-)
-i = len(occ_free_occl_)
-occ_free_occl, colors_per_voxel = occ_free_occl_[:i], colors_per_voxel_[:i]
-total_occ = (occ_free_occl == 1.0).sum(0)
-total_free = (occ_free_occl == -1.0).sum(0)
-ratio = total_occ / (total_occ + total_free) * ((total_occ + total_free) > 1)
-
-grid_colors = colors_per_voxel.sum(0) / (total_occ[..., None])
-model_mask = ratio > 0.2
-
-resolution = 0.0015
-
-grid_points = grid[model_mask]
-colors = grid_colors[model_mask]
-
-meshes = b3d.mesh.transform_mesh(
-    jax.vmap(b3d.mesh.Mesh.cube_mesh)(
-        jnp.ones((grid_points.shape[0], 3)) * resolution * 2.0, colors
-    ),
-    b3d.Pose.from_translation(grid_points)[:, None],
-)
-_object_mesh = b3d.mesh.Mesh.squeeze_mesh(meshes)
-
-object_pose = Pose.from_translation(jnp.median(_object_mesh.vertices, axis=0))
-object_mesh = _object_mesh.transform(object_pose.inv())
-object_mesh.rr_visualize("mesh")
-
-mesh_filename = filename + ".mesh.obj"
-# Save the mesh
-print(f"Saving obj file to {mesh_filename}")
-object_mesh.save(mesh_filename)
-
-renderer = b3d.RendererOriginal(image_width, image_height, fx, fy, cx, cy, near, far)
-rgbds = renderer.render_rgbd_many(
-    (camera_poses[:, None].inv() @ object_pose).apply(object_mesh.vertices),
-    object_mesh.faces,
-    jnp.tile(object_mesh.vertex_attributes, (len(camera_poses), 1, 1)),
-)
-
-sub_indices = jnp.array([0, 5, len(camera_poses) - 15, len(camera_poses) - 5])
-mask = rgbds[sub_indices, ..., 3] == 0.0
-
-background_xyzs = xyz_world_frame[sub_indices][mask]
-colors = rgbs_resized[sub_indices][mask, :]
-distances_from_camera = xyz[sub_indices][..., 2][mask][..., None] / fx
-
-# subset = jax.random.choice(jax.random.PRNGKey(0), jnp.arange(background_xyzs.shape[0]), shape=(background_xyzs.shape[0]//3,), replace=False)
-
-# background_xyzs = background_xyzs[subset]
-# colors = colors[subset]
-# distances_from_camera = distances_from_camera[subset]
-
-meshes = b3d.mesh.transform_mesh(
-    jax.vmap(b3d.mesh.Mesh.cube_mesh)(
-        jnp.ones((background_xyzs.shape[0], 3)) * distances_from_camera, colors
-    ),
-    b3d.Pose.from_translation(background_xyzs)[:, None],
-)
-background_mesh = b3d.mesh.Mesh.squeeze_mesh(meshes)
-background_mesh.rr_visualize("background_mesh")
-
-
-object_poses = [
-    object_pose,
-    Pose.identity(),
-    object_pose @ Pose.from_translation(jnp.array([-0.1, 0.0, 0.1])),
-    object_pose @ Pose.from_translation(jnp.array([-0.1, 0.0, -0.1])),
-]
-
-scene_mesh = b3d.mesh.transform_and_merge_meshes(
-    [object_mesh, background_mesh, object_mesh, object_mesh],
-    object_poses,
-)
-
-viz_images = []
-for t in tqdm(range(len(camera_poses_full))):
-    b3d.utils.rr_set_time(t)
-    rgbd = renderer.render_rgbd_from_mesh(
-        scene_mesh.transform(camera_poses_full[t].inv())
+    background_mesh = b3d.mesh.Mesh.squeeze_mesh(meshes)
+    background_mesh.rr_visualize("background_mesh")
+
+    object_poses = [
+        object_pose,
+        Pose.identity(),
+        object_pose @ Pose.from_translation(jnp.array([-0.1, 0.0, 0.1])),
+        object_pose @ Pose.from_translation(jnp.array([-0.1, 0.0, -0.1])),
+    ]
+
+    scene_mesh = b3d.mesh.transform_and_merge_meshes(
+        [object_mesh, background_mesh, object_mesh, object_mesh],
+        object_poses,
     )
-    viz_images.append(b3d.viz_rgb(rgbd))
+
+    viz_images = []
+    for t in tqdm(range(len(camera_poses_full))):
+        b3d.utils.rr_set_time(t)
+        rgbd = renderer.render_rgbd_from_mesh(
+            scene_mesh.transform(camera_poses_full[t].inv())
+        )
+        viz_images.append(b3d.viz_rgb(rgbd))
+
+    b3d.make_video_from_pil_images(viz_images, output_path, fps=30.0)
+    print(f"Saved video to {output_path}")
+    return output_path
+
+
+def main():
+    parser = argparse.ArgumentParser("acquire_object_mode")
+    parser.add_argument("input", help="r3d file", type=str)
+    args = parser.parse_args()
+    filename = args.input
+    return acquire(filename)
 
 
-b3d.make_video_from_pil_images(viz_images, filename + ".graphics_edits.mp4", fps=30.0)
-print(f"Saved video to {filename + '.graphics_edits.mp4'}")
+if __name__ == "__main__":
+    main()
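
The net effect of this diff is that the flat acquisition script is wrapped in acquire(input_path, output_path=None) with a main() CLI entry point, so the pipeline can also be driven from other code. Below is a minimal usage sketch, not part of the commit: it assumes the script lives at scripts/acquire_object_model.py (as the run comment in the diff suggests), reuses the example .r3d path from that comment, and uses a hypothetical output location.

import importlib.util

# Load the script by file path so we do not have to assume scripts/ is an importable package.
spec = importlib.util.spec_from_file_location(
    "acquire_object_model", "scripts/acquire_object_model.py"
)
acquire_object_model = importlib.util.module_from_spec(spec)
spec.loader.exec_module(acquire_object_model)  # also runs the module-level b3d.rr_init(...)

# acquire() writes <input>.mesh.obj next to the input, renders the edited video
# (default <input>.graphics_edits.mp4, overridden here), and returns the video path.
video_path = acquire_object_model.acquire(
    "assets/shared_data_bucket/input_data/lysol_static.r3d",  # example path from the diff
    output_path="/tmp/lysol_graphics_edits.mp4",  # hypothetical override
)
print(video_path)

Loading the script by path via importlib is just one way to call it without touching the repository layout; executing the module also runs the module-level b3d.rr_init("acquire_object_model") call before acquire() is available.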