8
8
from b3d .chisight .dense .dense_likelihood import make_dense_observation_model , DenseImageLikelihoodArgs
9
9
from b3d import Pose , Mesh
10
10
from b3d .chisight .sparse .gps_utils import add_dummy_var
11
- from b3d .pose . pose_utils import uniform_pose_in_ball
11
+ from b3d .pose import uniform_pose_in_ball
12
12
dummy_mapped_uniform_pose = add_dummy_var (uniform_pose_in_ball ).vmap (in_axes = (0 ,None ,None ,None ))
13
13
14
14
@@ -122,9 +122,15 @@ def latent_particle_model(
122
122
camera_pose_prior_params
123
123
):
124
124
"""
125
- Retval is a dict with keys "relative_particle_poses", "absolute_particle_poses",
126
- "object_poses", "camera_poses", "vis_mask"
127
- Leading dimension for each timestep is the batch dimension.
125
+ The retval is a dict with keys "object_assignments" and "masked_dynamic_state".
126
+ The value at "masked_dynamic_state" is a genjax.Mask object `m`.
127
+ `m.value` is a dictionary with keys "relative_particle_poses", "absolute_particle_poses",
128
+ "object_poses", "camera_poses", "vis_mask".
129
+ The leading dimension for each will have size `max_num_timesteps`.
130
+ The boolean array `m.flag` will indicate which of these timesteps are valid
131
+ (and which are values >= `num_timesteps`).
132
+ The values at these invalid timesteps are undefined.
133
+ Using these values directly will cause silent errors.
128
134
"""
129
135
(state0 , init_retval ) = initial_particle_system_state (
130
136
num_particles , num_clusters ,
@@ -155,7 +161,14 @@ def latent_particle_model(
155
161
jnp .concatenate ([jnp .array ([True ]), masked_scan_retvals .flag ]),
156
162
concatenated_states_possibly_invalid
157
163
)
158
- return masked_concatenated_states
164
+
165
+ object_assignments = state0 [1 ][0 ]
166
+ latent_dynamics_summary = {
167
+ "object_assignments" : object_assignments ,
168
+ "masked_dynamic_state" : masked_concatenated_states ,
169
+ }
170
+
171
+ return latent_dynamics_summary
159
172
160
173
@genjax .gen
161
174
def sparse_observation_model (particle_absolute_poses , camera_pose , visibility , instrinsics , sigma ):
@@ -166,16 +179,17 @@ def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, i
166
179
167
180
@genjax .gen
168
181
def sparse_gps_model (latent_particle_model_args , obs_model_args ):
169
- masked_particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
182
+ latent_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
183
+ masked_particle_dynamics_summary = latent_dynamics_summary ["masked_dynamic_state" ]
170
184
_UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary .value
171
185
masked_obs = sparse_observation_model .mask ().vmap (in_axes = (0 , 0 , 0 , 0 , None , None ))(
172
186
masked_particle_dynamics_summary .flag ,
173
187
_UNSAFE_particle_dynamics_summary ["absolute_particle_poses" ],
174
188
_UNSAFE_particle_dynamics_summary ["camera_pose" ],
175
189
_UNSAFE_particle_dynamics_summary ["vis_mask" ],
176
190
* obs_model_args
177
- ) @ "observation "
178
- return (masked_particle_dynamics_summary , masked_obs )
191
+ ) @ "obs "
192
+ return (latent_dynamics_summary , masked_obs )
179
193
180
194
181
195
@@ -184,7 +198,8 @@ def make_dense_gps_model(renderer):
184
198
185
199
@genjax .gen
186
200
def dense_gps_model (latent_particle_model_args , dense_likelihood_args ):
187
- masked_particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
201
+ latent_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
202
+ masked_particle_dynamics_summary = latent_dynamics_summary ["masked_dynamic_state" ]
188
203
_UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary .value
189
204
190
205
last_timestep_index = jnp .sum (masked_particle_dynamics_summary .flag ) - 1
@@ -194,7 +209,58 @@ def dense_gps_model(latent_particle_model_args, dense_likelihood_args):
194
209
195
210
(meshes , likelihood_args ) = dense_likelihood_args
196
211
merged_mesh = Mesh .transform_and_merge_meshes (meshes , absolute_particle_poses_in_camera_frame )
197
- image = dense_observation_model (merged_mesh , likelihood_args ) @ "observation"
198
- return (masked_particle_dynamics_summary , image )
212
+ image = dense_observation_model (merged_mesh , likelihood_args ) @ "obs"
213
+ return (latent_dynamics_summary , image )
214
+
215
+ return dense_gps_model
216
+
217
+
218
+ def visualize_particle_system (latent_particle_model_args , latent_dynamics_summary ):
219
+ import rerun as rr
220
+ (
221
+ max_num_timesteps , # const object
222
+ num_timesteps ,
223
+ num_particles , # const object
224
+ num_clusters , # const object
225
+ relative_particle_poses_prior_params ,
226
+ initial_object_poses_prior_params ,
227
+ camera_pose_prior_params
228
+ ) = latent_particle_model_args
229
+
230
+ colors = b3d .distinct_colors (num_clusters .const )
231
+
232
+ masked_particle_dynamics_summary = latent_dynamics_summary ["masked_dynamic_state" ]
233
+ object_assignments = latent_dynamics_summary ["object_assignments" ]
234
+ _UNSAFE_absolute_particle_poses = masked_particle_dynamics_summary .value ["absolute_particle_poses" ]
235
+ _UNSAFE_object_poses = masked_particle_dynamics_summary .value ["object_poses" ]
236
+ _UNSAFE_camera_pose = masked_particle_dynamics_summary .value ["camera_pose" ]
237
+
238
+ cluster_colors = jnp .array (b3d .distinct_colors (num_clusters .const ))
239
+
240
+ for t in range (num_timesteps ):
241
+ rr .set_time_sequence ("time" , t )
242
+ assert masked_particle_dynamics_summary .flag [t ], "Erroring before attempting to unmask invalid masked data."
243
+
244
+ cam_pose = _UNSAFE_camera_pose [t ]
245
+ rr .log (
246
+ f"/camera" ,
247
+ rr .Transform3D (translation = cam_pose .position , rotation = rr .Quaternion (xyzw = cam_pose .xyzw )),
248
+ )
249
+ rr .log (
250
+ f"/camera" ,
251
+ rr .Pinhole (
252
+ resolution = [0.1 ,0.1 ],
253
+ focal_length = 0.1 ,
254
+ ),
255
+ )
256
+
257
+ rr .log (
258
+ "absolute_particle_poses" ,
259
+ rr .Points3D (
260
+ _UNSAFE_absolute_particle_poses [t ].pos ,
261
+ colors = cluster_colors [object_assignments ]
262
+ )
263
+ )
199
264
200
- return dense_gps_model
265
+ for i in range (num_clusters .const ):
266
+ b3d .rr_log_pose (f"cluster/{ i } " , _UNSAFE_object_poses [t ][i ])
0 commit comments