8
8
from b3d .chisight .dense .dense_likelihood import make_dense_observation_model , DenseImageLikelihoodArgs
9
9
from b3d import Pose , Mesh
10
10
from b3d .chisight .sparse .gps_utils import add_dummy_var
11
- from b3d .chisight . sparse . pose_utils import uniform_pose_in_ball
11
+ from b3d .pose import uniform_pose_in_ball
12
12
dummy_mapped_uniform_pose = add_dummy_var (uniform_pose_in_ball ).vmap (in_axes = (0 ,None ,None ,None ))
13
13
14
14
@@ -138,7 +138,7 @@ def latent_particle_model(
138
138
return jax .tree .map (
139
139
lambda t1 , t2 : jnp .concatenate ([t1 [None , :], t2 ], axis = 0 ),
140
140
init_retval , scan_retvals
141
- )
141
+ ), final_state
142
142
143
143
@genjax .gen
144
144
def sparse_observation_model (particle_absolute_poses , camera_pose , visibility , instrinsics , sigma ):
@@ -150,14 +150,14 @@ def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, i
150
150
@genjax .gen
151
151
def sparse_gps_model (latent_particle_model_args , obs_model_args ):
152
152
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
153
- particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
153
+ particle_dynamics_summary , final_state = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
154
154
obs = sparse_observation_model .vmap (in_axes = (0 , 0 , 0 , None , None ))(
155
155
particle_dynamics_summary ["absolute_particle_poses" ],
156
156
particle_dynamics_summary ["camera_pose" ],
157
157
particle_dynamics_summary ["vis_mask" ],
158
158
* obs_model_args
159
- ) @ "observation "
160
- return (particle_dynamics_summary , obs )
159
+ ) @ "obs "
160
+ return (particle_dynamics_summary , final_state , obs )
161
161
162
162
163
163
@@ -167,14 +167,63 @@ def make_dense_gps_model(renderer):
167
167
@genjax .gen
168
168
def dense_gps_model (latent_particle_model_args , dense_likelihood_args ):
169
169
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
170
- particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
170
+ particle_dynamics_summary , final_state = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
171
171
absolute_particle_poses_last_frame = particle_dynamics_summary ["absolute_particle_poses" ][- 1 ]
172
172
camera_pose_last_frame = particle_dynamics_summary ["camera_pose" ][- 1 ]
173
173
absolute_particle_poses_in_camera_frame = camera_pose_last_frame .inv () @ absolute_particle_poses_last_frame
174
174
175
175
(meshes , likelihood_args ) = dense_likelihood_args
176
176
merged_mesh = Mesh .transform_and_merge_meshes (meshes , absolute_particle_poses_in_camera_frame )
177
- image = dense_observation_model (merged_mesh , likelihood_args ) @ "observation"
178
- return (particle_dynamics_summary , image )
177
+ image = dense_observation_model (merged_mesh , likelihood_args ) @ "obs"
178
+ return (particle_dynamics_summary , final_state , image )
179
+
180
+ return dense_gps_model
181
+
179
182
180
- return dense_gps_model
183
+ def visualize_particle_system (latent_particle_model_args , particle_dynamics_summary , final_state ):
184
+ import rerun as rr
185
+ (dynamic_state , static_state ) = final_state
186
+
187
+ (
188
+ num_timesteps , # const object
189
+ num_particles , # const object
190
+ num_clusters , # const object
191
+ relative_particle_poses_prior_params ,
192
+ initial_object_poses_prior_params ,
193
+ camera_pose_prior_params
194
+ ) = latent_particle_model_args
195
+
196
+ colors = b3d .distinct_colors (num_clusters .const )
197
+ absolute_particle_poses = particle_dynamics_summary ["absolute_particle_poses" ]
198
+ object_poses = particle_dynamics_summary ["object_poses" ]
199
+ camera_pose = particle_dynamics_summary ["camera_pose" ]
200
+ object_assignments = static_state [0 ]
201
+
202
+ cluster_colors = jnp .array (b3d .distinct_colors (num_clusters .const ))
203
+
204
+ for t in range (num_timesteps .const ):
205
+ rr .set_time_sequence ("time" , t )
206
+
207
+ cam_pose = camera_pose [t ]
208
+ rr .log (
209
+ f"/camera" ,
210
+ rr .Transform3D (translation = cam_pose .position , rotation = rr .Quaternion (xyzw = cam_pose .xyzw )),
211
+ )
212
+ rr .log (
213
+ f"/camera" ,
214
+ rr .Pinhole (
215
+ resolution = [0.1 ,0.1 ],
216
+ focal_length = 0.1 ,
217
+ ),
218
+ )
219
+
220
+ rr .log (
221
+ "absolute_particle_poses" ,
222
+ rr .Points3D (
223
+ absolute_particle_poses [t ].pos ,
224
+ colors = cluster_colors [object_assignments ]
225
+ )
226
+ )
227
+
228
+ for i in range (num_clusters .const ):
229
+ b3d .rr_log_pose (f"cluster/{ i } " , object_poses [t ][i ])
0 commit comments