Skip to content

Commit f2dc6d5

Browse files
committed
dynamic length scan in particle system
2 parents 5a07ed3 + 7f1fa15 commit f2dc6d5

File tree

3 files changed

+151
-22
lines changed

3 files changed

+151
-22
lines changed

b3d/chisight/shared/particle_system.py

+78-12
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from b3d.chisight.dense.dense_likelihood import make_dense_observation_model, DenseImageLikelihoodArgs
99
from b3d import Pose, Mesh
1010
from b3d.chisight.sparse.gps_utils import add_dummy_var
11-
from b3d.pose.pose_utils import uniform_pose_in_ball
11+
from b3d.pose import uniform_pose_in_ball
1212
dummy_mapped_uniform_pose = add_dummy_var(uniform_pose_in_ball).vmap(in_axes=(0,None,None,None))
1313

1414

@@ -122,9 +122,15 @@ def latent_particle_model(
122122
camera_pose_prior_params
123123
):
124124
"""
125-
Retval is a dict with keys "relative_particle_poses", "absolute_particle_poses",
126-
"object_poses", "camera_poses", "vis_mask"
127-
Leading dimension for each timestep is the batch dimension.
125+
The retval is a dict with keys "object_assignments" and "masked_dynamic_state".
126+
The value at "masked_dynamic_state" is a genjax.Mask object `m`.
127+
`m.value` is a dictionary with keys "relative_particle_poses", "absolute_particle_poses",
128+
"object_poses", "camera_poses", "vis_mask".
129+
The leading dimension for each will have size `max_num_timesteps`.
130+
The boolean array `m.flag` will indicate which of these timesteps are valid
131+
(and which are values >= `num_timesteps`).
132+
The values at these invalid timesteps are undefined.
133+
Using these values directly will cause silent errors.
128134
"""
129135
(state0, init_retval) = initial_particle_system_state(
130136
num_particles, num_clusters,
@@ -155,7 +161,14 @@ def latent_particle_model(
155161
jnp.concatenate([jnp.array([True]), masked_scan_retvals.flag]),
156162
concatenated_states_possibly_invalid
157163
)
158-
return masked_concatenated_states
164+
165+
object_assignments = state0[1][0]
166+
latent_dynamics_summary = {
167+
"object_assignments": object_assignments,
168+
"masked_dynamic_state": masked_concatenated_states,
169+
}
170+
171+
return latent_dynamics_summary
159172

160173
@genjax.gen
161174
def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, instrinsics, sigma):
@@ -166,16 +179,17 @@ def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, i
166179

167180
@genjax.gen
168181
def sparse_gps_model(latent_particle_model_args, obs_model_args):
169-
masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
182+
latent_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
183+
masked_particle_dynamics_summary = latent_dynamics_summary["masked_dynamic_state"]
170184
_UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value
171185
masked_obs = sparse_observation_model.mask().vmap(in_axes=(0, 0, 0, 0, None, None))(
172186
masked_particle_dynamics_summary.flag,
173187
_UNSAFE_particle_dynamics_summary["absolute_particle_poses"],
174188
_UNSAFE_particle_dynamics_summary["camera_pose"],
175189
_UNSAFE_particle_dynamics_summary["vis_mask"],
176190
*obs_model_args
177-
) @ "observation"
178-
return (masked_particle_dynamics_summary, masked_obs)
191+
) @ "obs"
192+
return (latent_dynamics_summary, masked_obs)
179193

180194

181195

@@ -184,7 +198,8 @@ def make_dense_gps_model(renderer):
184198

185199
@genjax.gen
186200
def dense_gps_model(latent_particle_model_args, dense_likelihood_args):
187-
masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
201+
latent_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
202+
masked_particle_dynamics_summary = latent_dynamics_summary["masked_dynamic_state"]
188203
_UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value
189204

190205
last_timestep_index = jnp.sum(masked_particle_dynamics_summary.flag) - 1
@@ -194,7 +209,58 @@ def dense_gps_model(latent_particle_model_args, dense_likelihood_args):
194209

195210
(meshes, likelihood_args) = dense_likelihood_args
196211
merged_mesh = Mesh.transform_and_merge_meshes(meshes, absolute_particle_poses_in_camera_frame)
197-
image = dense_observation_model(merged_mesh, likelihood_args) @ "observation"
198-
return (masked_particle_dynamics_summary, image)
212+
image = dense_observation_model(merged_mesh, likelihood_args) @ "obs"
213+
return (latent_dynamics_summary, image)
214+
215+
return dense_gps_model
216+
217+
218+
def visualize_particle_system(latent_particle_model_args, latent_dynamics_summary):
219+
import rerun as rr
220+
(
221+
max_num_timesteps, # const object
222+
num_timesteps,
223+
num_particles, # const object
224+
num_clusters, # const object
225+
relative_particle_poses_prior_params,
226+
initial_object_poses_prior_params,
227+
camera_pose_prior_params
228+
) = latent_particle_model_args
229+
230+
colors = b3d.distinct_colors(num_clusters.const)
231+
232+
masked_particle_dynamics_summary = latent_dynamics_summary["masked_dynamic_state"]
233+
object_assignments = latent_dynamics_summary["object_assignments"]
234+
_UNSAFE_absolute_particle_poses = masked_particle_dynamics_summary.value["absolute_particle_poses"]
235+
_UNSAFE_object_poses = masked_particle_dynamics_summary.value["object_poses"]
236+
_UNSAFE_camera_pose = masked_particle_dynamics_summary.value["camera_pose"]
237+
238+
cluster_colors = jnp.array(b3d.distinct_colors(num_clusters.const))
239+
240+
for t in range(num_timesteps):
241+
rr.set_time_sequence("time", t)
242+
assert masked_particle_dynamics_summary.flag[t], "Erroring before attempting to unmask invalid masked data."
243+
244+
cam_pose = _UNSAFE_camera_pose[t]
245+
rr.log(
246+
f"/camera",
247+
rr.Transform3D(translation=cam_pose.position, rotation=rr.Quaternion(xyzw=cam_pose.xyzw)),
248+
)
249+
rr.log(
250+
f"/camera",
251+
rr.Pinhole(
252+
resolution=[0.1,0.1],
253+
focal_length=0.1,
254+
),
255+
)
256+
257+
rr.log(
258+
"absolute_particle_poses",
259+
rr.Points3D(
260+
_UNSAFE_absolute_particle_poses[t].pos,
261+
colors=cluster_colors[object_assignments]
262+
)
263+
)
199264

200-
return dense_gps_model
265+
for i in range(num_clusters.const):
266+
b3d.rr_log_pose(f"cluster/{i}", _UNSAFE_object_poses[t][i])

notebooks/integration.ipynb

+19-10
Large diffs are not rendered by default.

tests/test_chisight_sparse_gps.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import b3d
2+
from b3d.renderer.renderer_original import RendererOriginal
3+
from b3d.chisight.dense.dense_likelihood import DenseImageLikelihoodArgs, get_rgb_depth_inliers_from_observed_rendered_args
4+
import jax
5+
import jax.numpy as jnp
6+
import os
7+
from b3d import Pose, Mesh
8+
9+
import b3d.chisight.shared.particle_system as ps
10+
import genjax
11+
from genjax import Pytree
12+
import jax
13+
from b3d import Pose
14+
import b3d
15+
16+
17+
def test_sparse_gps_simulate():
18+
renderer = RendererOriginal()
19+
key = jax.random.PRNGKey(100)
20+
21+
max_num_timesteps = Pytree.const(20)
22+
num_timesteps = 10
23+
num_particles = Pytree.const(100)
24+
num_clusters = Pytree.const(4)
25+
relative_particle_poses_prior_params = (Pose.identity(), .5, 0.25)
26+
initial_object_poses_prior_params = (Pose.identity(), 2., 0.5)
27+
camera_pose_prior_params = (Pose.identity(), 0.1, 0.1)
28+
instrinsics = Pytree.const(b3d.camera.Intrinsics(120, 100, 50., 50., 50., 50., 0.001, 16.))
29+
sigma_obs = 0.2
30+
31+
32+
b3d.rr_init()
33+
34+
trace = ps.sparse_gps_model.simulate(key, (
35+
(
36+
max_num_timesteps, # const object
37+
num_timesteps,
38+
num_particles, # const object
39+
num_clusters, # const object
40+
relative_particle_poses_prior_params,
41+
initial_object_poses_prior_params,
42+
camera_pose_prior_params
43+
),
44+
(instrinsics, sigma_obs)
45+
))
46+
47+
48+
particle_system_summary = trace.get_retval()[0]
49+
latent_particle_model_args = trace.get_args()[0]
50+
51+
# import importlib
52+
# importlib.reload(b3d.chisight.shared.particle_system)
53+
54+
ps.visualize_particle_system(latent_particle_model_args, particle_system_summary)

0 commit comments

Comments
 (0)