8
8
from b3d .chisight .dense .dense_likelihood import make_dense_observation_model , DenseImageLikelihoodArgs
9
9
from b3d import Pose , Mesh
10
10
from b3d .chisight .sparse .gps_utils import add_dummy_var
11
- from b3d .chisight . sparse .pose_utils import uniform_pose_in_ball
11
+ from b3d .pose .pose_utils import uniform_pose_in_ball
12
12
dummy_mapped_uniform_pose = add_dummy_var (uniform_pose_in_ball ).vmap (in_axes = (0 ,None ,None ,None ))
13
13
14
14
@@ -113,7 +113,8 @@ def particle_system_state_step(carried_state, _):
113
113
114
114
@gen
115
115
def latent_particle_model (
116
- num_timesteps , # const object
116
+ max_num_timesteps , # const object
117
+ num_timesteps ,
117
118
num_particles , # const object
118
119
num_clusters , # const object
119
120
relative_particle_poses_prior_params ,
@@ -132,32 +133,49 @@ def latent_particle_model(
132
133
camera_pose_prior_params
133
134
) @ "state0"
134
135
135
- final_state , scan_retvals = particle_system_state_step .scan (n = (num_timesteps .const - 1 ))(state0 , None ) @ "states1+"
136
+ masked_final_state , masked_scan_retvals = b3d .modeling_utils .masked_scan_combinator (
137
+ particle_system_state_step ,
138
+ n = (max_num_timesteps .const - 1 )
139
+ )(
140
+ state0 ,
141
+ genjax .Mask (
142
+ # This next line tells the scan combinator how many timesteps to run
143
+ jnp .arange (max_num_timesteps .const - 1 ) < num_timesteps - 1 ,
144
+ jnp .zeros (max_num_timesteps .const - 1 )
145
+ )
146
+ ) @ "states1+"
147
+
136
148
137
149
# concatenate each element of init_retval, scan_retvals
138
- return jax .tree .map (
150
+ concatenated_states_possibly_invalid = jax .tree .map (
139
151
lambda t1 , t2 : jnp .concatenate ([t1 [None , :], t2 ], axis = 0 ),
140
- init_retval , scan_retvals
152
+ init_retval , masked_scan_retvals .value
153
+ )
154
+ masked_concatenated_states = genjax .Mask (
155
+ jnp .concatenate ([jnp .array ([True ]), masked_scan_retvals .flag ]),
156
+ concatenated_states_possibly_invalid
141
157
)
158
+ return masked_concatenated_states
142
159
143
160
@genjax .gen
144
161
def sparse_observation_model (particle_absolute_poses , camera_pose , visibility , instrinsics , sigma ):
145
162
# TODO: add visibility
146
163
uv = b3d .camera .screen_from_world (particle_absolute_poses .pos , camera_pose , instrinsics .const )
147
- uv_ = genjax .normal (uv , jnp .tile (sigma , uv .shape )) @ "sensor_coordinates"
164
+ uv_ = b3d . modeling_utils .normal (uv , jnp .tile (sigma , uv .shape )) @ "sensor_coordinates"
148
165
return uv_
149
166
150
167
@genjax .gen
151
168
def sparse_gps_model (latent_particle_model_args , obs_model_args ):
152
- # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
153
- particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
154
- obs = sparse_observation_model .vmap (in_axes = (0 , 0 , 0 , None , None ))(
155
- particle_dynamics_summary ["absolute_particle_poses" ],
156
- particle_dynamics_summary ["camera_pose" ],
157
- particle_dynamics_summary ["vis_mask" ],
169
+ masked_particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
170
+ _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary .value
171
+ masked_obs = sparse_observation_model .mask ().vmap (in_axes = (0 , 0 , 0 , 0 , None , None ))(
172
+ masked_particle_dynamics_summary .flag ,
173
+ _UNSAFE_particle_dynamics_summary ["absolute_particle_poses" ],
174
+ _UNSAFE_particle_dynamics_summary ["camera_pose" ],
175
+ _UNSAFE_particle_dynamics_summary ["vis_mask" ],
158
176
* obs_model_args
159
177
) @ "observation"
160
- return (particle_dynamics_summary , obs )
178
+ return (masked_particle_dynamics_summary , masked_obs )
161
179
162
180
163
181
@@ -166,15 +184,17 @@ def make_dense_gps_model(renderer):
166
184
167
185
@genjax .gen
168
186
def dense_gps_model (latent_particle_model_args , dense_likelihood_args ):
169
- # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
170
- particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
171
- absolute_particle_poses_last_frame = particle_dynamics_summary ["absolute_particle_poses" ][- 1 ]
172
- camera_pose_last_frame = particle_dynamics_summary ["camera_pose" ][- 1 ]
187
+ masked_particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
188
+ _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary .value
189
+
190
+ last_timestep_index = jnp .sum (masked_particle_dynamics_summary .flag ) - 1
191
+ absolute_particle_poses_last_frame = _UNSAFE_particle_dynamics_summary ["absolute_particle_poses" ][last_timestep_index ]
192
+ camera_pose_last_frame = _UNSAFE_particle_dynamics_summary ["camera_pose" ][last_timestep_index ]
173
193
absolute_particle_poses_in_camera_frame = camera_pose_last_frame .inv () @ absolute_particle_poses_last_frame
174
194
175
195
(meshes , likelihood_args ) = dense_likelihood_args
176
196
merged_mesh = Mesh .transform_and_merge_meshes (meshes , absolute_particle_poses_in_camera_frame )
177
197
image = dense_observation_model (merged_mesh , likelihood_args ) @ "observation"
178
- return (particle_dynamics_summary , image )
198
+ return (masked_particle_dynamics_summary , image )
179
199
180
200
return dense_gps_model
0 commit comments