Skip to content

Commit 4bb14fc

Browse files
authored
Make object acquisition demo callable by webapp (#81)
This PR:
- moves the body of the script into a function I can call from elsewhere
- lets the user supply an output path
1 parent 0256164 commit 4bb14fc

File tree

1 file changed

+161
-146
lines changed

1 file changed

+161
-146
lines changed

scripts/acquire_object_model.py

+161-146
Original file line numberDiff line numberDiff line change
@@ -8,154 +8,169 @@
88

99
b3d.rr_init("acquire_object_model")
1010

11+
# python scripts/acquire_object_model.py assets/shared_data_bucket/input_data/lysol_static.r3d
1112

12-
parser = argparse.ArgumentParser("acquire_object_mode")
13-
parser.add_argument("input", help="r3d file", type=str)
14-
args = parser.parse_args()
15-
16-
filename = args.input
17-
data = b3d.io.load_r3d(filename)
18-
19-
20-
_, _, fx, fy, cx, cy, near, far = data["camera_intrinsics_depth"]
21-
image_height, image_width = data["depth"].shape[1:3]
22-
num_scenes = data["depth"].shape[0]
23-
24-
indices = jnp.arange(0, num_scenes, 10)
25-
26-
camera_poses_full = data["camera_pose"]
27-
camera_poses = camera_poses_full[indices]
28-
29-
xyz = b3d.xyz_from_depth_vectorized(data["depth"][indices], fx, fy, cx, cy)
30-
xyz_world_frame = camera_poses[:, None, None].apply(xyz)
31-
32-
# for i in range(len(xyz_world_frame)):
33-
# b3d.rr_set_time(i)
34-
# b3d.utils.rr_log_cloud("xyz", xyz_world_frame[i])
35-
36-
# Resize rgbs to be same size as depth.
37-
rgbs = data["rgb"]
38-
rgbs_resized = jnp.clip(
39-
jax.vmap(jax.image.resize, in_axes=(0, None, None))(
40-
rgbs[indices] / 255.0,
41-
(image_height, image_width, 3),
42-
"linear",
43-
),
44-
0.0,
45-
1.0,
46-
)
47-
48-
49-
masks = [b3d.carvekit_get_foreground_mask(r) for r in rgbs_resized]
50-
masks_concat = jnp.stack(masks, axis=0)
51-
52-
grid_center = jnp.median(camera_poses[0].apply(xyz[0][masks[0]]), axis=0)
53-
W = 0.3
54-
D = 100
55-
grid = jnp.stack(
56-
jnp.meshgrid(
57-
jnp.linspace(grid_center[0] - W / 2, grid_center[0] + W / 2, D),
58-
jnp.linspace(grid_center[1] - W / 2, grid_center[1] + W / 2, D),
59-
jnp.linspace(grid_center[2] - W / 2, grid_center[2] + W / 2, D),
60-
),
61-
axis=-1,
62-
).reshape(-1, 3)
63-
64-
occ_free_occl_, colors_per_voxel_ = (
65-
b3d.voxel_occupied_occluded_free_parallel_camera_depth(
66-
camera_poses,
67-
rgbs_resized,
68-
xyz[..., 2] * masks_concat + (1.0 - masks_concat) * 5.0,
69-
grid,
70-
fx,
71-
fy,
72-
cx,
73-
cy,
74-
6.0,
75-
0.005,
13+
14+
# ssh sam-b3d-l4.us-west1-a.probcomp-caliban -L 5000:localhost:5000
15+
16+
17+
def acquire(input_path, output_path=None):
18+
if output_path is None:
19+
output_path = input_path + ".graphics_edits.mp4"
20+
21+
data = b3d.io.load_r3d(input_path)
22+
23+
_, _, fx, fy, cx, cy, near, far = data["camera_intrinsics_depth"]
24+
image_height, image_width = data["depth"].shape[1:3]
25+
num_scenes = data["depth"].shape[0]
26+
27+
indices = jnp.arange(0, num_scenes, 10)
28+
29+
camera_poses_full = data["camera_pose"]
30+
camera_poses = camera_poses_full[indices]
31+
32+
xyz = b3d.xyz_from_depth_vectorized(data["depth"][indices], fx, fy, cx, cy)
33+
xyz_world_frame = camera_poses[:, None, None].apply(xyz)
34+
35+
# for i in range(len(xyz_world_frame)):
36+
# b3d.rr_set_time(i)
37+
# b3d.utils.rr_log_cloud("xyz", xyz_world_frame[i])
38+
39+
# Resize rgbs to be same size as depth.
40+
rgbs = data["rgb"]
41+
rgbs_resized = jnp.clip(
42+
jax.vmap(jax.image.resize, in_axes=(0, None, None))(
43+
rgbs[indices] / 255.0,
44+
(image_height, image_width, 3),
45+
"linear",
46+
),
47+
0.0,
48+
1.0,
49+
)
50+
51+
masks = [b3d.carvekit_get_foreground_mask(r) for r in rgbs_resized]
52+
masks_concat = jnp.stack(masks, axis=0)
53+
54+
grid_center = jnp.median(camera_poses[0].apply(xyz[0][masks[0]]), axis=0)
55+
W = 0.3
56+
D = 100
57+
grid = jnp.stack(
58+
jnp.meshgrid(
59+
jnp.linspace(grid_center[0] - W / 2, grid_center[0] + W / 2, D),
60+
jnp.linspace(grid_center[1] - W / 2, grid_center[1] + W / 2, D),
61+
jnp.linspace(grid_center[2] - W / 2, grid_center[2] + W / 2, D),
62+
),
63+
axis=-1,
64+
).reshape(-1, 3)
65+
66+
occ_free_occl_, colors_per_voxel_ = (
67+
b3d.voxel_occupied_occluded_free_parallel_camera_depth(
68+
camera_poses,
69+
rgbs_resized,
70+
xyz[..., 2] * masks_concat + (1.0 - masks_concat) * 5.0,
71+
grid,
72+
fx,
73+
fy,
74+
cx,
75+
cy,
76+
6.0,
77+
0.005,
78+
)
79+
)
80+
i = len(occ_free_occl_)
81+
occ_free_occl, colors_per_voxel = occ_free_occl_[:i], colors_per_voxel_[:i]
82+
total_occ = (occ_free_occl == 1.0).sum(0)
83+
total_free = (occ_free_occl == -1.0).sum(0)
84+
ratio = total_occ / (total_occ + total_free) * ((total_occ + total_free) > 1)
85+
86+
grid_colors = colors_per_voxel.sum(0) / (total_occ[..., None])
87+
model_mask = ratio > 0.2
88+
89+
resolution = 0.0015
90+
91+
grid_points = grid[model_mask]
92+
colors = grid_colors[model_mask]
93+
94+
meshes = b3d.mesh.transform_mesh(
95+
jax.vmap(b3d.mesh.Mesh.cube_mesh)(
96+
jnp.ones((grid_points.shape[0], 3)) * resolution * 2.0, colors
97+
),
98+
b3d.Pose.from_translation(grid_points)[:, None],
99+
)
100+
_object_mesh = b3d.mesh.Mesh.squeeze_mesh(meshes)
101+
102+
object_pose = Pose.from_translation(jnp.median(_object_mesh.vertices, axis=0))
103+
object_mesh = _object_mesh.transform(object_pose.inv())
104+
object_mesh.rr_visualize("mesh")
105+
106+
mesh_filename = input_path + ".mesh.obj"
107+
# Save the mesh
108+
print(f"Saving obj file to {mesh_filename}")
109+
object_mesh.save(mesh_filename)
110+
111+
renderer = b3d.RendererOriginal(
112+
image_width, image_height, fx, fy, cx, cy, near, far
113+
)
114+
rgbds = renderer.render_rgbd_many(
115+
(camera_poses[:, None].inv() @ object_pose).apply(object_mesh.vertices),
116+
object_mesh.faces,
117+
jnp.tile(object_mesh.vertex_attributes, (len(camera_poses), 1, 1)),
118+
)
119+
120+
sub_indices = jnp.array([0, 5, len(camera_poses) - 15, len(camera_poses) - 5])
121+
mask = rgbds[sub_indices, ..., 3] == 0.0
122+
123+
background_xyzs = xyz_world_frame[sub_indices][mask]
124+
colors = rgbs_resized[sub_indices][mask, :]
125+
distances_from_camera = xyz[sub_indices][..., 2][mask][..., None] / fx
126+
127+
# subset = jax.random.choice(jax.random.PRNGKey(0), jnp.arange(background_xyzs.shape[0]), shape=(background_xyzs.shape[0]//3,), replace=False)
128+
129+
# background_xyzs = background_xyzs[subset]
130+
# colors = colors[subset]
131+
# distances_from_camera = distances_from_camera[subset]
132+
133+
meshes = b3d.mesh.transform_mesh(
134+
jax.vmap(b3d.mesh.Mesh.cube_mesh)(
135+
jnp.ones((background_xyzs.shape[0], 3)) * distances_from_camera, colors
136+
),
137+
b3d.Pose.from_translation(background_xyzs)[:, None],
76138
)
77-
)
78-
i = len(occ_free_occl_)
79-
occ_free_occl, colors_per_voxel = occ_free_occl_[:i], colors_per_voxel_[:i]
80-
total_occ = (occ_free_occl == 1.0).sum(0)
81-
total_free = (occ_free_occl == -1.0).sum(0)
82-
ratio = total_occ / (total_occ + total_free) * ((total_occ + total_free) > 1)
83-
84-
grid_colors = colors_per_voxel.sum(0) / (total_occ[..., None])
85-
model_mask = ratio > 0.2
86-
87-
resolution = 0.0015
88-
89-
grid_points = grid[model_mask]
90-
colors = grid_colors[model_mask]
91-
92-
meshes = b3d.mesh.transform_mesh(
93-
jax.vmap(b3d.mesh.Mesh.cube_mesh)(
94-
jnp.ones((grid_points.shape[0], 3)) * resolution * 2.0, colors
95-
),
96-
b3d.Pose.from_translation(grid_points)[:, None],
97-
)
98-
_object_mesh = b3d.mesh.Mesh.squeeze_mesh(meshes)
99-
100-
object_pose = Pose.from_translation(jnp.median(_object_mesh.vertices, axis=0))
101-
object_mesh = _object_mesh.transform(object_pose.inv())
102-
object_mesh.rr_visualize("mesh")
103-
104-
mesh_filename = filename + ".mesh.obj"
105-
# Save the mesh
106-
print(f"Saving obj file to {mesh_filename}")
107-
object_mesh.save(mesh_filename)
108-
109-
renderer = b3d.RendererOriginal(image_width, image_height, fx, fy, cx, cy, near, far)
110-
rgbds = renderer.render_rgbd_many(
111-
(camera_poses[:, None].inv() @ object_pose).apply(object_mesh.vertices),
112-
object_mesh.faces,
113-
jnp.tile(object_mesh.vertex_attributes, (len(camera_poses), 1, 1)),
114-
)
115-
116-
sub_indices = jnp.array([0, 5, len(camera_poses) - 15, len(camera_poses) - 5])
117-
mask = rgbds[sub_indices, ..., 3] == 0.0
118-
119-
background_xyzs = xyz_world_frame[sub_indices][mask]
120-
colors = rgbs_resized[sub_indices][mask, :]
121-
distances_from_camera = xyz[sub_indices][..., 2][mask][..., None] / fx
122-
123-
# subset = jax.random.choice(jax.random.PRNGKey(0), jnp.arange(background_xyzs.shape[0]), shape=(background_xyzs.shape[0]//3,), replace=False)
124-
125-
# background_xyzs = background_xyzs[subset]
126-
# colors = colors[subset]
127-
# distances_from_camera = distances_from_camera[subset]
128-
129-
meshes = b3d.mesh.transform_mesh(
130-
jax.vmap(b3d.mesh.Mesh.cube_mesh)(
131-
jnp.ones((background_xyzs.shape[0], 3)) * distances_from_camera, colors
132-
),
133-
b3d.Pose.from_translation(background_xyzs)[:, None],
134-
)
135-
background_mesh = b3d.mesh.Mesh.squeeze_mesh(meshes)
136-
background_mesh.rr_visualize("background_mesh")
137-
138-
139-
object_poses = [
140-
object_pose,
141-
Pose.identity(),
142-
object_pose @ Pose.from_translation(jnp.array([-0.1, 0.0, 0.1])),
143-
object_pose @ Pose.from_translation(jnp.array([-0.1, 0.0, -0.1])),
144-
]
145-
146-
scene_mesh = b3d.mesh.transform_and_merge_meshes(
147-
[object_mesh, background_mesh, object_mesh, object_mesh],
148-
object_poses,
149-
)
150-
151-
viz_images = []
152-
for t in tqdm(range(len(camera_poses_full))):
153-
b3d.utils.rr_set_time(t)
154-
rgbd = renderer.render_rgbd_from_mesh(
155-
scene_mesh.transform(camera_poses_full[t].inv())
139+
background_mesh = b3d.mesh.Mesh.squeeze_mesh(meshes)
140+
background_mesh.rr_visualize("background_mesh")
141+
142+
object_poses = [
143+
object_pose,
144+
Pose.identity(),
145+
object_pose @ Pose.from_translation(jnp.array([-0.1, 0.0, 0.1])),
146+
object_pose @ Pose.from_translation(jnp.array([-0.1, 0.0, -0.1])),
147+
]
148+
149+
scene_mesh = b3d.mesh.transform_and_merge_meshes(
150+
[object_mesh, background_mesh, object_mesh, object_mesh],
151+
object_poses,
156152
)
157-
viz_images.append(b3d.viz_rgb(rgbd))
153+
154+
viz_images = []
155+
for t in tqdm(range(len(camera_poses_full))):
156+
b3d.utils.rr_set_time(t)
157+
rgbd = renderer.render_rgbd_from_mesh(
158+
scene_mesh.transform(camera_poses_full[t].inv())
159+
)
160+
viz_images.append(b3d.viz_rgb(rgbd))
161+
162+
b3d.make_video_from_pil_images(viz_images, output_path, fps=30.0)
163+
print(f"Saved video to {output_path}")
164+
return output_path
165+
166+
167+
def main():
    """Command-line entry point for the object acquisition demo.

    Parses the command line and forwards the arguments to ``acquire``.

    Returns:
        The value returned by ``acquire`` (the path of the rendered video).
    """
    # Program name matches the script name (acquire_object_model.py); the
    # previous value "acquire_object_mode" was a typo shown in usage/help text.
    parser = argparse.ArgumentParser("acquire_object_model")
    parser.add_argument("input", help="r3d file", type=str)
    # Optional so existing invocations keep working: when omitted, None is
    # passed through and ``acquire`` picks its own default output path.
    parser.add_argument(
        "--output",
        help="path of the output video (default: <input>.graphics_edits.mp4)",
        type=str,
        default=None,
    )
    args = parser.parse_args()
    return acquire(args.input, output_path=args.output)
158173

159174

160-
b3d.make_video_from_pil_images(viz_images, filename + ".graphics_edits.mp4", fps=30.0)
161-
print(f"Saved video to {filename + '.graphics_edits.mp4'}")
175+
if __name__ == "__main__":
176+
main()

0 commit comments

Comments
 (0)