Skip to content

Commit 7f1fa15

Browse files
Merge pull request #45 from probcomp/sparse_gps_viz
Visualization of Particle System
2 parents 4571451 + 4d37506 commit 7f1fa15

File tree

2 files changed

+111
-9
lines changed

2 files changed

+111
-9
lines changed

b3d/chisight/shared/particle_system.py

+58-9
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from b3d.chisight.dense.dense_likelihood import make_dense_observation_model, DenseImageLikelihoodArgs
99
from b3d import Pose, Mesh
1010
from b3d.chisight.sparse.gps_utils import add_dummy_var
11-
from b3d.chisight.sparse.pose_utils import uniform_pose_in_ball
11+
from b3d.pose import uniform_pose_in_ball
1212
dummy_mapped_uniform_pose = add_dummy_var(uniform_pose_in_ball).vmap(in_axes=(0,None,None,None))
1313

1414

@@ -138,7 +138,7 @@ def latent_particle_model(
138138
return jax.tree.map(
139139
lambda t1, t2: jnp.concatenate([t1[None, :], t2], axis=0),
140140
init_retval, scan_retvals
141-
)
141+
), final_state
142142

143143
@genjax.gen
144144
def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, instrinsics, sigma):
@@ -150,14 +150,14 @@ def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, i
150150
@genjax.gen
151151
def sparse_gps_model(latent_particle_model_args, obs_model_args):
152152
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
153-
particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
153+
particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
154154
obs = sparse_observation_model.vmap(in_axes=(0, 0, 0, None, None))(
155155
particle_dynamics_summary["absolute_particle_poses"],
156156
particle_dynamics_summary["camera_pose"],
157157
particle_dynamics_summary["vis_mask"],
158158
*obs_model_args
159-
) @ "observation"
160-
return (particle_dynamics_summary, obs)
159+
) @ "obs"
160+
return (particle_dynamics_summary, final_state, obs)
161161

162162

163163

@@ -167,14 +167,63 @@ def make_dense_gps_model(renderer):
167167
@genjax.gen
168168
def dense_gps_model(latent_particle_model_args, dense_likelihood_args):
169169
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
170-
particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
170+
particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
171171
absolute_particle_poses_last_frame = particle_dynamics_summary["absolute_particle_poses"][-1]
172172
camera_pose_last_frame = particle_dynamics_summary["camera_pose"][-1]
173173
absolute_particle_poses_in_camera_frame = camera_pose_last_frame.inv() @ absolute_particle_poses_last_frame
174174

175175
(meshes, likelihood_args) = dense_likelihood_args
176176
merged_mesh = Mesh.transform_and_merge_meshes(meshes, absolute_particle_poses_in_camera_frame)
177-
image = dense_observation_model(merged_mesh, likelihood_args) @ "observation"
178-
return (particle_dynamics_summary, image)
177+
image = dense_observation_model(merged_mesh, likelihood_args) @ "obs"
178+
return (particle_dynamics_summary, final_state, image)
179+
180+
return dense_gps_model
181+
179182

180-
return dense_gps_model
183+
def visualize_particle_system(latent_particle_model_args, particle_dynamics_summary, final_state):
184+
import rerun as rr
185+
(dynamic_state, static_state) = final_state
186+
187+
(
188+
num_timesteps, # const object
189+
num_particles, # const object
190+
num_clusters, # const object
191+
relative_particle_poses_prior_params,
192+
initial_object_poses_prior_params,
193+
camera_pose_prior_params
194+
) = latent_particle_model_args
195+
196+
colors = b3d.distinct_colors(num_clusters.const)
197+
absolute_particle_poses = particle_dynamics_summary["absolute_particle_poses"]
198+
object_poses = particle_dynamics_summary["object_poses"]
199+
camera_pose = particle_dynamics_summary["camera_pose"]
200+
object_assignments = static_state[0]
201+
202+
cluster_colors = jnp.array(b3d.distinct_colors(num_clusters.const))
203+
204+
for t in range(num_timesteps.const):
205+
rr.set_time_sequence("time", t)
206+
207+
cam_pose = camera_pose[t]
208+
rr.log(
209+
f"/camera",
210+
rr.Transform3D(translation=cam_pose.position, rotation=rr.Quaternion(xyzw=cam_pose.xyzw)),
211+
)
212+
rr.log(
213+
f"/camera",
214+
rr.Pinhole(
215+
resolution=[0.1,0.1],
216+
focal_length=0.1,
217+
),
218+
)
219+
220+
rr.log(
221+
"absolute_particle_poses",
222+
rr.Points3D(
223+
absolute_particle_poses[t].pos,
224+
colors=cluster_colors[object_assignments]
225+
)
226+
)
227+
228+
for i in range(num_clusters.const):
229+
b3d.rr_log_pose(f"cluster/{i}", object_poses[t][i])

tests/test_chisight_sparse_gps.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import b3d
2+
from b3d.renderer.renderer_original import RendererOriginal
3+
from b3d.chisight.dense.dense_likelihood import DenseImageLikelihoodArgs, get_rgb_depth_inliers_from_observed_rendered_args
4+
import jax
5+
import jax.numpy as jnp
6+
import os
7+
from b3d import Pose, Mesh
8+
9+
import b3d.chisight.shared.particle_system as ps
10+
import genjax
11+
from genjax import Pytree
12+
import jax
13+
from b3d import Pose
14+
import b3d
15+
16+
17+
def test_sparse_gps_simulate():
18+
renderer = RendererOriginal()
19+
key = jax.random.PRNGKey(100)
20+
21+
num_timesteps = Pytree.const(10)
22+
num_particles = Pytree.const(100)
23+
num_clusters = Pytree.const(4)
24+
relative_particle_poses_prior_params = (Pose.identity(), .5, 0.25)
25+
initial_object_poses_prior_params = (Pose.identity(), 2., 0.5)
26+
camera_pose_prior_params = (Pose.identity(), 0.1, 0.1)
27+
instrinsics = Pytree.const(b3d.camera.Intrinsics(120, 100, 50., 50., 50., 50., 0.001, 16.))
28+
sigma_obs = 0.2
29+
30+
31+
b3d.rr_init()
32+
33+
trace = ps.sparse_gps_model.simulate(key, (
34+
(
35+
num_timesteps, # const object
36+
num_particles, # const object
37+
num_clusters, # const object
38+
relative_particle_poses_prior_params,
39+
initial_object_poses_prior_params,
40+
camera_pose_prior_params
41+
),
42+
(instrinsics, sigma_obs)
43+
))
44+
45+
46+
particle_dynamics_summary = trace.get_retval()[0]
47+
final_state = trace.get_retval()[1]
48+
latent_particle_model_args = trace.get_args()[0]
49+
50+
# import importlib
51+
# importlib.reload(b3d.chisight.shared.particle_system)
52+
53+
ps.visualize_particle_system(latent_particle_model_args, particle_dynamics_summary, final_state)

0 commit comments

Comments
 (0)