Skip to content

Commit ed3d5ca

Browse files
committed
variable length unfold model
1 parent 5709bdc commit ed3d5ca

File tree

3 files changed

+182
-30
lines changed

3 files changed

+182
-30
lines changed

b3d/chisight/shared/particle_system.py

+57-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
from b3d.chisight.dense.dense_likelihood import make_dense_observation_model, DenseImageLikelihoodArgs
99
from b3d import Pose, Mesh
1010
from b3d.chisight.sparse.gps_utils import add_dummy_var
11+
<<<<<<< HEAD
1112
from b3d.pose import uniform_pose_in_ball
13+
=======
14+
from b3d.pose.pose_utils import uniform_pose_in_ball
15+
>>>>>>> 5a07ed3 (variable length unfold model)
1216
dummy_mapped_uniform_pose = add_dummy_var(uniform_pose_in_ball).vmap(in_axes=(0,None,None,None))
1317

1418

@@ -113,7 +117,8 @@ def particle_system_state_step(carried_state, _):
113117

114118
@gen
115119
def latent_particle_model(
116-
num_timesteps, # const object
120+
max_num_timesteps, # const object
121+
num_timesteps,
117122
num_particles, # const object
118123
num_clusters, # const object
119124
relative_particle_poses_prior_params,
@@ -132,23 +137,45 @@ def latent_particle_model(
132137
camera_pose_prior_params
133138
) @ "state0"
134139

135-
final_state, scan_retvals = particle_system_state_step.scan(n=(num_timesteps.const - 1))(state0, None) @ "states1+"
140+
masked_final_state, masked_scan_retvals = b3d.modeling_utils.masked_scan_combinator(
141+
particle_system_state_step,
142+
n=(max_num_timesteps.const-1)
143+
)(
144+
state0,
145+
genjax.Mask(
146+
# This next line tells the scan combinator how many timesteps to run
147+
jnp.arange(max_num_timesteps.const - 1) < num_timesteps - 1,
148+
jnp.zeros(max_num_timesteps.const - 1)
149+
)
150+
) @ "states1+"
151+
136152

137153
# concatenate each element of init_retval, scan_retvals
138-
return jax.tree.map(
154+
concatenated_states_possibly_invalid = jax.tree.map(
139155
lambda t1, t2: jnp.concatenate([t1[None, :], t2], axis=0),
156+
<<<<<<< HEAD
140157
init_retval, scan_retvals
141158
), final_state
159+
=======
160+
init_retval, masked_scan_retvals.value
161+
)
162+
masked_concatenated_states = genjax.Mask(
163+
jnp.concatenate([jnp.array([True]), masked_scan_retvals.flag]),
164+
concatenated_states_possibly_invalid
165+
)
166+
return masked_concatenated_states
167+
>>>>>>> 5a07ed3 (variable length unfold model)
142168

143169
@genjax.gen
144170
def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, instrinsics, sigma):
145171
# TODO: add visibility
146172
uv = b3d.camera.screen_from_world(particle_absolute_poses.pos, camera_pose, instrinsics.const)
147-
uv_ = genjax.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates"
173+
uv_ = b3d.modeling_utils.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates"
148174
return uv_
149175

150176
@genjax.gen
151177
def sparse_gps_model(latent_particle_model_args, obs_model_args):
178+
<<<<<<< HEAD
152179
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
153180
particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
154181
obs = sparse_observation_model.vmap(in_axes=(0, 0, 0, None, None))(
@@ -158,6 +185,18 @@ def sparse_gps_model(latent_particle_model_args, obs_model_args):
158185
*obs_model_args
159186
) @ "obs"
160187
return (particle_dynamics_summary, final_state, obs)
188+
=======
189+
masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
190+
_UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value
191+
masked_obs = sparse_observation_model.mask().vmap(in_axes=(0, 0, 0, 0, None, None))(
192+
masked_particle_dynamics_summary.flag,
193+
_UNSAFE_particle_dynamics_summary["absolute_particle_poses"],
194+
_UNSAFE_particle_dynamics_summary["camera_pose"],
195+
_UNSAFE_particle_dynamics_summary["vis_mask"],
196+
*obs_model_args
197+
) @ "observation"
198+
return (masked_particle_dynamics_summary, masked_obs)
199+
>>>>>>> 5a07ed3 (variable length unfold model)
161200

162201

163202

@@ -166,19 +205,33 @@ def make_dense_gps_model(renderer):
166205

167206
@genjax.gen
168207
def dense_gps_model(latent_particle_model_args, dense_likelihood_args):
208+
<<<<<<< HEAD
169209
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
170210
particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
171211
absolute_particle_poses_last_frame = particle_dynamics_summary["absolute_particle_poses"][-1]
172212
camera_pose_last_frame = particle_dynamics_summary["camera_pose"][-1]
213+
=======
214+
masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
215+
_UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value
216+
217+
last_timestep_index = jnp.sum(masked_particle_dynamics_summary.flag) - 1
218+
absolute_particle_poses_last_frame = _UNSAFE_particle_dynamics_summary["absolute_particle_poses"][last_timestep_index]
219+
camera_pose_last_frame = _UNSAFE_particle_dynamics_summary["camera_pose"][last_timestep_index]
220+
>>>>>>> 5a07ed3 (variable length unfold model)
173221
absolute_particle_poses_in_camera_frame = camera_pose_last_frame.inv() @ absolute_particle_poses_last_frame
174222

175223
(meshes, likelihood_args) = dense_likelihood_args
176224
merged_mesh = Mesh.transform_and_merge_meshes(meshes, absolute_particle_poses_in_camera_frame)
225+
<<<<<<< HEAD
177226
image = dense_observation_model(merged_mesh, likelihood_args) @ "obs"
178227
return (particle_dynamics_summary, final_state, image)
179228

180229
return dense_gps_model
181230

231+
=======
232+
image = dense_observation_model(merged_mesh, likelihood_args) @ "observation"
233+
return (masked_particle_dynamics_summary, image)
234+
>>>>>>> 5a07ed3 (variable length unfold model)
182235

183236
def visualize_particle_system(latent_particle_model_args, particle_dynamics_summary, final_state):
184237
import rerun as rr

b3d/modeling_utils.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import jax
44
import jax.numpy as jnp
55
from tensorflow_probability.substrates import jax as tfp
6+
from genjax import Mask
67

78
uniform_discrete = genjax.exact_density(
89
lambda key, vals: jax.random.choice(key, vals),
@@ -35,4 +36,56 @@ def logpdf(v, *args, **kwargs):
3536
d = dist(*args, **kwargs)
3637
return jnp.sum(d.log_prob(v))
3738

38-
return genjax.exact_density(sampler, logpdf)
39+
return genjax.exact_density(sampler, logpdf)
40+
41+
normal = tfp_distribution(tfp.distributions.Normal)
42+
43+
def masked_scan_combinator(step, **scan_kwargs):
44+
"""
45+
Given a generative function `step` so that `step.scan(n=N)` is valid,
46+
return a generative function accepting an input
47+
`(initial_state, masked_input_values_array)` and returning a pair
48+
`(masked_final_state, masked_returnvalue_sequence)`.
49+
This operates similarly to `step.scan`, but the input values can be masked.
50+
"""
51+
mstep = step.mask().dimap(
52+
pre=lambda masked_state, masked_inval: (
53+
jnp.logical_and(masked_state.flag, masked_inval.flag),
54+
masked_state.value,
55+
masked_inval.value
56+
),
57+
post=lambda args, masked_retval: (
58+
Mask(masked_retval.flag, masked_retval.value[0]),
59+
Mask(masked_retval.flag, masked_retval.value[1])
60+
)
61+
)
62+
63+
# This should be given a pair (
64+
# Mask(True, initial_state),
65+
# Mask(bools_indicating_active, input_vals)
66+
# ).
67+
# It wll output a pair (masked_final_state, masked_returnvalue_sequence).
68+
scanned = mstep.scan(**scan_kwargs)
69+
70+
scanned_nice = scanned.dimap(
71+
pre=lambda initial_state, masked_input_values: (
72+
Mask(True, initial_state),
73+
Mask(masked_input_values.flag, masked_input_values.value)
74+
),
75+
post=lambda args, retval: retval
76+
)
77+
78+
return scanned_nice
79+
80+
def variable_length_unfold_combinator(step, **scan_kwargs):
81+
"""
82+
Step should accept one arg, `state`, as input,
83+
and should return a pair `(new_state, retval_for_this_timestep)`.
84+
"""
85+
scanned = masked_scan_combinator(step, **scan_kwargs)
86+
return scanned.dimap(
87+
pre=lambda initial_state, n_steps: (
88+
initial_state,
89+
Mask(jnp.array())
90+
)
91+
)

notebooks/integration.ipynb

+71-25
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)