Skip to content

Commit 5a07ed3

Browse files
committed
variable length unfold model
1 parent 6911ee6 commit 5a07ed3

File tree

3 files changed

+163
-44
lines changed

3 files changed

+163
-44
lines changed

b3d/chisight/shared/particle_system.py

+38-18
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from b3d.chisight.dense.dense_likelihood import make_dense_observation_model, DenseImageLikelihoodArgs
99
from b3d import Pose, Mesh
1010
from b3d.chisight.sparse.gps_utils import add_dummy_var
11-
from b3d.chisight.sparse.pose_utils import uniform_pose_in_ball
11+
from b3d.pose.pose_utils import uniform_pose_in_ball
1212
dummy_mapped_uniform_pose = add_dummy_var(uniform_pose_in_ball).vmap(in_axes=(0,None,None,None))
1313

1414

@@ -113,7 +113,8 @@ def particle_system_state_step(carried_state, _):
113113

114114
@gen
115115
def latent_particle_model(
116-
num_timesteps, # const object
116+
max_num_timesteps, # const object
117+
num_timesteps,
117118
num_particles, # const object
118119
num_clusters, # const object
119120
relative_particle_poses_prior_params,
@@ -132,32 +133,49 @@ def latent_particle_model(
132133
camera_pose_prior_params
133134
) @ "state0"
134135

135-
final_state, scan_retvals = particle_system_state_step.scan(n=(num_timesteps.const - 1))(state0, None) @ "states1+"
136+
masked_final_state, masked_scan_retvals = b3d.modeling_utils.masked_scan_combinator(
137+
particle_system_state_step,
138+
n=(max_num_timesteps.const-1)
139+
)(
140+
state0,
141+
genjax.Mask(
142+
# This next line tells the scan combinator how many timesteps to run
143+
jnp.arange(max_num_timesteps.const - 1) < num_timesteps - 1,
144+
jnp.zeros(max_num_timesteps.const - 1)
145+
)
146+
) @ "states1+"
147+
136148

137149
# concatenate each element of init_retval, scan_retvals
138-
return jax.tree.map(
150+
concatenated_states_possibly_invalid = jax.tree.map(
139151
lambda t1, t2: jnp.concatenate([t1[None, :], t2], axis=0),
140-
init_retval, scan_retvals
152+
init_retval, masked_scan_retvals.value
153+
)
154+
masked_concatenated_states = genjax.Mask(
155+
jnp.concatenate([jnp.array([True]), masked_scan_retvals.flag]),
156+
concatenated_states_possibly_invalid
141157
)
158+
return masked_concatenated_states
142159

143160
@genjax.gen
144161
def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, instrinsics, sigma):
145162
# TODO: add visibility
146163
uv = b3d.camera.screen_from_world(particle_absolute_poses.pos, camera_pose, instrinsics.const)
147-
uv_ = genjax.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates"
164+
uv_ = b3d.modeling_utils.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates"
148165
return uv_
149166

150167
@genjax.gen
151168
def sparse_gps_model(latent_particle_model_args, obs_model_args):
152-
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
153-
particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
154-
obs = sparse_observation_model.vmap(in_axes=(0, 0, 0, None, None))(
155-
particle_dynamics_summary["absolute_particle_poses"],
156-
particle_dynamics_summary["camera_pose"],
157-
particle_dynamics_summary["vis_mask"],
169+
masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
170+
_UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value
171+
masked_obs = sparse_observation_model.mask().vmap(in_axes=(0, 0, 0, 0, None, None))(
172+
masked_particle_dynamics_summary.flag,
173+
_UNSAFE_particle_dynamics_summary["absolute_particle_poses"],
174+
_UNSAFE_particle_dynamics_summary["camera_pose"],
175+
_UNSAFE_particle_dynamics_summary["vis_mask"],
158176
*obs_model_args
159177
) @ "observation"
160-
return (particle_dynamics_summary, obs)
178+
return (masked_particle_dynamics_summary, masked_obs)
161179

162180

163181

@@ -166,15 +184,17 @@ def make_dense_gps_model(renderer):
166184

167185
@genjax.gen
168186
def dense_gps_model(latent_particle_model_args, dense_likelihood_args):
169-
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
170-
particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
171-
absolute_particle_poses_last_frame = particle_dynamics_summary["absolute_particle_poses"][-1]
172-
camera_pose_last_frame = particle_dynamics_summary["camera_pose"][-1]
187+
masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
188+
_UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value
189+
190+
last_timestep_index = jnp.sum(masked_particle_dynamics_summary.flag) - 1
191+
absolute_particle_poses_last_frame = _UNSAFE_particle_dynamics_summary["absolute_particle_poses"][last_timestep_index]
192+
camera_pose_last_frame = _UNSAFE_particle_dynamics_summary["camera_pose"][last_timestep_index]
173193
absolute_particle_poses_in_camera_frame = camera_pose_last_frame.inv() @ absolute_particle_poses_last_frame
174194

175195
(meshes, likelihood_args) = dense_likelihood_args
176196
merged_mesh = Mesh.transform_and_merge_meshes(meshes, absolute_particle_poses_in_camera_frame)
177197
image = dense_observation_model(merged_mesh, likelihood_args) @ "observation"
178-
return (particle_dynamics_summary, image)
198+
return (masked_particle_dynamics_summary, image)
179199

180200
return dense_gps_model

b3d/modeling_utils.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import jax
44
import jax.numpy as jnp
55
from tensorflow_probability.substrates import jax as tfp
6+
from genjax import Mask
67

78
uniform_discrete = genjax.exact_density(
89
lambda key, vals: jax.random.choice(key, vals),
@@ -35,4 +36,56 @@ def logpdf(v, *args, **kwargs):
3536
d = dist(*args, **kwargs)
3637
return jnp.sum(d.log_prob(v))
3738

38-
return genjax.exact_density(sampler, logpdf)
39+
return genjax.exact_density(sampler, logpdf)
40+
41+
normal = tfp_distribution(tfp.distributions.Normal)
42+
43+
def masked_scan_combinator(step, **scan_kwargs):
44+
"""
45+
Given a generative function `step` so that `step.scan(n=N)` is valid,
46+
return a generative function accepting an input
47+
`(initial_state, masked_input_values_array)` and returning a pair
48+
`(masked_final_state, masked_returnvalue_sequence)`.
49+
This operates similarly to `step.scan`, but the input values can be masked.
50+
"""
51+
mstep = step.mask().dimap(
52+
pre=lambda masked_state, masked_inval: (
53+
jnp.logical_and(masked_state.flag, masked_inval.flag),
54+
masked_state.value,
55+
masked_inval.value
56+
),
57+
post=lambda args, masked_retval: (
58+
Mask(masked_retval.flag, masked_retval.value[0]),
59+
Mask(masked_retval.flag, masked_retval.value[1])
60+
)
61+
)
62+
63+
# This should be given a pair (
64+
# Mask(True, initial_state),
65+
# Mask(bools_indicating_active, input_vals)
66+
# ).
67+
# It wll output a pair (masked_final_state, masked_returnvalue_sequence).
68+
scanned = mstep.scan(**scan_kwargs)
69+
70+
scanned_nice = scanned.dimap(
71+
pre=lambda initial_state, masked_input_values: (
72+
Mask(True, initial_state),
73+
Mask(masked_input_values.flag, masked_input_values.value)
74+
),
75+
post=lambda args, retval: retval
76+
)
77+
78+
return scanned_nice
79+
80+
def variable_length_unfold_combinator(step, **scan_kwargs):
81+
"""
82+
Step should accept one arg, `state`, as input,
83+
and should return a pair `(new_state, retval_for_this_timestep)`.
84+
"""
85+
scanned = masked_scan_combinator(step, **scan_kwargs)
86+
return scanned.dimap(
87+
pre=lambda initial_state, n_steps: (
88+
initial_state,
89+
Mask(jnp.array())
90+
)
91+
)

notebooks/integration.ipynb

+71-25
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)