8
8
from b3d .chisight .dense .dense_likelihood import make_dense_observation_model , DenseImageLikelihoodArgs
9
9
from b3d import Pose , Mesh
10
10
from b3d .chisight .sparse .gps_utils import add_dummy_var
11
+ < << << << HEAD
11
12
from b3d .pose import uniform_pose_in_ball
13
+ == == == =
14
+ from b3d .pose .pose_utils import uniform_pose_in_ball
15
+ > >> >> >> 5 a07ed3 (variable length unfold model )
12
16
dummy_mapped_uniform_pose = add_dummy_var (uniform_pose_in_ball ).vmap (in_axes = (0 ,None ,None ,None ))
13
17
14
18
@@ -113,7 +117,8 @@ def particle_system_state_step(carried_state, _):
113
117
114
118
@gen
115
119
def latent_particle_model (
116
- num_timesteps , # const object
120
+ max_num_timesteps , # const object
121
+ num_timesteps ,
117
122
num_particles , # const object
118
123
num_clusters , # const object
119
124
relative_particle_poses_prior_params ,
@@ -132,23 +137,45 @@ def latent_particle_model(
132
137
camera_pose_prior_params
133
138
) @ "state0"
134
139
135
- final_state , scan_retvals = particle_system_state_step .scan (n = (num_timesteps .const - 1 ))(state0 , None ) @ "states1+"
140
+ masked_final_state , masked_scan_retvals = b3d .modeling_utils .masked_scan_combinator (
141
+ particle_system_state_step ,
142
+ n = (max_num_timesteps .const - 1 )
143
+ )(
144
+ state0 ,
145
+ genjax .Mask (
146
+ # This next line tells the scan combinator how many timesteps to run
147
+ jnp .arange (max_num_timesteps .const - 1 ) < num_timesteps - 1 ,
148
+ jnp .zeros (max_num_timesteps .const - 1 )
149
+ )
150
+ ) @ "states1+"
151
+
136
152
137
153
# concatenate each element of init_retval, scan_retvals
138
- return jax .tree .map (
154
+ concatenated_states_possibly_invalid = jax .tree .map (
139
155
lambda t1 , t2 : jnp .concatenate ([t1 [None , :], t2 ], axis = 0 ),
156
+ << << << < HEAD
140
157
init_retval , scan_retvals
141
158
), final_state
159
+ == == == =
160
+ init_retval , masked_scan_retvals .value
161
+ )
162
+ masked_concatenated_states = genjax .Mask (
163
+ jnp .concatenate ([jnp .array ([True ]), masked_scan_retvals .flag ]),
164
+ concatenated_states_possibly_invalid
165
+ )
166
+ return masked_concatenated_states
167
+ >> >> >> > 5 a07ed3 (variable length unfold model )
142
168
143
169
@genjax .gen
144
170
def sparse_observation_model (particle_absolute_poses , camera_pose , visibility , instrinsics , sigma ):
145
171
# TODO: add visibility
146
172
uv = b3d .camera .screen_from_world (particle_absolute_poses .pos , camera_pose , instrinsics .const )
147
- uv_ = genjax .normal (uv , jnp .tile (sigma , uv .shape )) @ "sensor_coordinates"
173
+ uv_ = b3d . modeling_utils .normal (uv , jnp .tile (sigma , uv .shape )) @ "sensor_coordinates"
148
174
return uv_
149
175
150
176
@genjax .gen
151
177
def sparse_gps_model (latent_particle_model_args , obs_model_args ):
178
+ < << << << HEAD
152
179
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
153
180
particle_dynamics_summary , final_state = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
154
181
obs = sparse_observation_model .vmap (in_axes = (0 , 0 , 0 , None , None ))(
@@ -158,6 +185,18 @@ def sparse_gps_model(latent_particle_model_args, obs_model_args):
158
185
* obs_model_args
159
186
) @ "obs"
160
187
return (particle_dynamics_summary , final_state , obs )
188
+ == == == =
189
+ masked_particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
190
+ _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary .value
191
+ masked_obs = sparse_observation_model .mask ().vmap (in_axes = (0 , 0 , 0 , 0 , None , None ))(
192
+ masked_particle_dynamics_summary .flag ,
193
+ _UNSAFE_particle_dynamics_summary ["absolute_particle_poses" ],
194
+ _UNSAFE_particle_dynamics_summary ["camera_pose" ],
195
+ _UNSAFE_particle_dynamics_summary ["vis_mask" ],
196
+ * obs_model_args
197
+ ) @ "observation"
198
+ return (masked_particle_dynamics_summary , masked_obs )
199
+ >> >> >> > 5 a07ed3 (variable length unfold model )
161
200
162
201
163
202
@@ -166,19 +205,33 @@ def make_dense_gps_model(renderer):
166
205
167
206
@genjax .gen
168
207
def dense_gps_model (latent_particle_model_args , dense_likelihood_args ):
208
+ < << << << HEAD
169
209
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
170
210
particle_dynamics_summary , final_state = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
171
211
absolute_particle_poses_last_frame = particle_dynamics_summary ["absolute_particle_poses" ][- 1 ]
172
212
camera_pose_last_frame = particle_dynamics_summary ["camera_pose" ][- 1 ]
213
+ == == == =
214
+ masked_particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
215
+ _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary .value
216
+
217
+ last_timestep_index = jnp .sum (masked_particle_dynamics_summary .flag ) - 1
218
+ absolute_particle_poses_last_frame = _UNSAFE_particle_dynamics_summary ["absolute_particle_poses" ][last_timestep_index ]
219
+ camera_pose_last_frame = _UNSAFE_particle_dynamics_summary ["camera_pose" ][last_timestep_index ]
220
+ >> >> >> > 5 a07ed3 (variable length unfold model )
173
221
absolute_particle_poses_in_camera_frame = camera_pose_last_frame .inv () @ absolute_particle_poses_last_frame
174
222
175
223
(meshes , likelihood_args ) = dense_likelihood_args
176
224
merged_mesh = Mesh .transform_and_merge_meshes (meshes , absolute_particle_poses_in_camera_frame )
225
+ < << << << HEAD
177
226
image = dense_observation_model (merged_mesh , likelihood_args ) @ "obs"
178
227
return (particle_dynamics_summary , final_state , image )
179
228
180
229
return dense_gps_model
181
230
231
+ == == == =
232
+ image = dense_observation_model (merged_mesh , likelihood_args ) @ "observation"
233
+ return (masked_particle_dynamics_summary , image )
234
+ >> >> >> > 5 a07ed3 (variable length unfold model )
182
235
183
236
def visualize_particle_system (latent_particle_model_args , particle_dynamics_summary , final_state ):
184
237
import rerun as rr
0 commit comments