Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gm/dynamic length scan #49

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 70 additions & 33 deletions b3d/chisight/shared/particle_system.py
Original file line number Diff line number Diff line change
@@ -113,17 +113,24 @@ def particle_system_state_step(carried_state, _):

@gen
def latent_particle_model(
num_timesteps, # const object
max_num_timesteps, # const object
num_timesteps,
num_particles, # const object
num_clusters, # const object
relative_particle_poses_prior_params,
initial_object_poses_prior_params,
camera_pose_prior_params
):
"""
Retval is a dict with keys "relative_particle_poses", "absolute_particle_poses",
"object_poses", "camera_poses", "vis_mask"
Leading dimension for each timestep is the batch dimension.
The retval is a dict with keys "object_assignments" and "masked_dynamic_state".
The value at "masked_dynamic_state" is a genjax.Mask object `m`.
`m.value` is a dictionary with keys "relative_particle_poses", "absolute_particle_poses",
"object_poses", "camera_poses", "vis_mask".
The leading dimension for each will have size `max_num_timesteps`.
The boolean array `m.flag` will indicate which of these timesteps are valid
(and which are values >= `num_timesteps`).
The values at these invalid timesteps are undefined.
Using these values directly will cause silent errors.
"""
(state0, init_retval) = initial_particle_system_state(
num_particles, num_clusters,
@@ -132,32 +139,57 @@ def latent_particle_model(
camera_pose_prior_params
) @ "state0"

final_state, scan_retvals = particle_system_state_step.scan(n=(num_timesteps.const - 1))(state0, None) @ "states1+"
masked_final_state, masked_scan_retvals = b3d.modeling_utils.masked_scan_combinator(
particle_system_state_step,
n=(max_num_timesteps.const-1)
)(
state0,
genjax.Mask(
# This next line tells the scan combinator how many timesteps to run
jnp.arange(max_num_timesteps.const - 1) < num_timesteps - 1,
jnp.zeros(max_num_timesteps.const - 1)
)
) @ "states1+"


# concatenate each element of init_retval, scan_retvals
return jax.tree.map(
concatenated_states_possibly_invalid = jax.tree.map(
lambda t1, t2: jnp.concatenate([t1[None, :], t2], axis=0),
init_retval, scan_retvals
), final_state
init_retval, masked_scan_retvals.value
)
masked_concatenated_states = genjax.Mask(
jnp.concatenate([jnp.array([True]), masked_scan_retvals.flag]),
concatenated_states_possibly_invalid
)

object_assignments = state0[1][0]
latent_dynamics_summary = {
"object_assignments": object_assignments,
"masked_dynamic_state": masked_concatenated_states,
}

return latent_dynamics_summary

@genjax.gen
def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, instrinsics, sigma):
# TODO: add visibility
uv = b3d.camera.screen_from_world(particle_absolute_poses.pos, camera_pose, instrinsics.const)
uv_ = genjax.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates"
uv_ = b3d.modeling_utils.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates"
return uv_

@genjax.gen
def sparse_gps_model(latent_particle_model_args, obs_model_args):
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
obs = sparse_observation_model.vmap(in_axes=(0, 0, 0, None, None))(
particle_dynamics_summary["absolute_particle_poses"],
particle_dynamics_summary["camera_pose"],
particle_dynamics_summary["vis_mask"],
latent_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
masked_particle_dynamics_summary = latent_dynamics_summary["masked_dynamic_state"]
_UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value
masked_obs = sparse_observation_model.mask().vmap(in_axes=(0, 0, 0, 0, None, None))(
masked_particle_dynamics_summary.flag,
_UNSAFE_particle_dynamics_summary["absolute_particle_poses"],
_UNSAFE_particle_dynamics_summary["camera_pose"],
_UNSAFE_particle_dynamics_summary["vis_mask"],
*obs_model_args
) @ "obs"
return (particle_dynamics_summary, final_state, obs)
return (latent_dynamics_summary, masked_obs)



@@ -166,26 +198,28 @@ def make_dense_gps_model(renderer):

@genjax.gen
def dense_gps_model(latent_particle_model_args, dense_likelihood_args):
# (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
absolute_particle_poses_last_frame = particle_dynamics_summary["absolute_particle_poses"][-1]
camera_pose_last_frame = particle_dynamics_summary["camera_pose"][-1]
latent_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics"
masked_particle_dynamics_summary = latent_dynamics_summary["masked_dynamic_state"]
_UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value

last_timestep_index = jnp.sum(masked_particle_dynamics_summary.flag) - 1
absolute_particle_poses_last_frame = _UNSAFE_particle_dynamics_summary["absolute_particle_poses"][last_timestep_index]
camera_pose_last_frame = _UNSAFE_particle_dynamics_summary["camera_pose"][last_timestep_index]
absolute_particle_poses_in_camera_frame = camera_pose_last_frame.inv() @ absolute_particle_poses_last_frame

(meshes, likelihood_args) = dense_likelihood_args
merged_mesh = Mesh.transform_and_merge_meshes(meshes, absolute_particle_poses_in_camera_frame)
image = dense_observation_model(merged_mesh, likelihood_args) @ "obs"
return (particle_dynamics_summary, final_state, image)
return (latent_dynamics_summary, image)

return dense_gps_model


def visualize_particle_system(latent_particle_model_args, particle_dynamics_summary, final_state):
def visualize_particle_system(latent_particle_model_args, latent_dynamics_summary):
import rerun as rr
(dynamic_state, static_state) = final_state

(
num_timesteps, # const object
max_num_timesteps, # const object
num_timesteps,
num_particles, # const object
num_clusters, # const object
relative_particle_poses_prior_params,
@@ -194,17 +228,20 @@ def visualize_particle_system(latent_particle_model_args, particle_dynamics_summ
) = latent_particle_model_args

colors = b3d.distinct_colors(num_clusters.const)
absolute_particle_poses = particle_dynamics_summary["absolute_particle_poses"]
object_poses = particle_dynamics_summary["object_poses"]
camera_pose = particle_dynamics_summary["camera_pose"]
object_assignments = static_state[0]

masked_particle_dynamics_summary = latent_dynamics_summary["masked_dynamic_state"]
object_assignments = latent_dynamics_summary["object_assignments"]
_UNSAFE_absolute_particle_poses = masked_particle_dynamics_summary.value["absolute_particle_poses"]
_UNSAFE_object_poses = masked_particle_dynamics_summary.value["object_poses"]
_UNSAFE_camera_pose = masked_particle_dynamics_summary.value["camera_pose"]

cluster_colors = jnp.array(b3d.distinct_colors(num_clusters.const))

for t in range(num_timesteps.const):
for t in range(num_timesteps):
rr.set_time_sequence("time", t)
assert masked_particle_dynamics_summary.flag[t], "Erroring before attempting to unmask invalid masked data."

cam_pose = camera_pose[t]
cam_pose = _UNSAFE_camera_pose[t]
rr.log(
f"/camera",
rr.Transform3D(translation=cam_pose.position, rotation=rr.Quaternion(xyzw=cam_pose.xyzw)),
@@ -220,10 +257,10 @@ def visualize_particle_system(latent_particle_model_args, particle_dynamics_summ
rr.log(
"absolute_particle_poses",
rr.Points3D(
absolute_particle_poses[t].pos,
_UNSAFE_absolute_particle_poses[t].pos,
colors=cluster_colors[object_assignments]
)
)

for i in range(num_clusters.const):
b3d.rr_log_pose(f"cluster/{i}", object_poses[t][i])
b3d.rr_log_pose(f"cluster/{i}", _UNSAFE_object_poses[t][i])
55 changes: 54 additions & 1 deletion b3d/modeling_utils.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
from genjax import Mask

uniform_discrete = genjax.exact_density(
lambda key, vals: jax.random.choice(key, vals),
@@ -35,4 +36,56 @@ def logpdf(v, *args, **kwargs):
d = dist(*args, **kwargs)
return jnp.sum(d.log_prob(v))

return genjax.exact_density(sampler, logpdf)
return genjax.exact_density(sampler, logpdf)

normal = tfp_distribution(tfp.distributions.Normal)

def masked_scan_combinator(step, **scan_kwargs):
"""
Given a generative function `step` so that `step.scan(n=N)` is valid,
return a generative function accepting an input
`(initial_state, masked_input_values_array)` and returning a pair
`(masked_final_state, masked_returnvalue_sequence)`.
This operates similarly to `step.scan`, but the input values can be masked.
"""
mstep = step.mask().dimap(
pre=lambda masked_state, masked_inval: (
jnp.logical_and(masked_state.flag, masked_inval.flag),
masked_state.value,
masked_inval.value
),
post=lambda args, masked_retval: (
Mask(masked_retval.flag, masked_retval.value[0]),
Mask(masked_retval.flag, masked_retval.value[1])
)
)

# This should be given a pair (
# Mask(True, initial_state),
# Mask(bools_indicating_active, input_vals)
# ).
# It wll output a pair (masked_final_state, masked_returnvalue_sequence).
scanned = mstep.scan(**scan_kwargs)

scanned_nice = scanned.dimap(
pre=lambda initial_state, masked_input_values: (
Mask(True, initial_state),
Mask(masked_input_values.flag, masked_input_values.value)
),
post=lambda args, retval: retval
)

return scanned_nice

def variable_length_unfold_combinator(step, **scan_kwargs):
"""
Step should accept one arg, `state`, as input,
and should return a pair `(new_state, retval_for_this_timestep)`.
"""
scanned = masked_scan_combinator(step, **scan_kwargs)
return scanned.dimap(
pre=lambda initial_state, n_steps: (
initial_state,
Mask(jnp.array())
)
)
96 changes: 71 additions & 25 deletions notebooks/integration.ipynb

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions tests/test_chisight_sparse_gps.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,8 @@ def test_sparse_gps_simulate():
renderer = RendererOriginal()
key = jax.random.PRNGKey(100)

num_timesteps = Pytree.const(10)
max_num_timesteps = Pytree.const(20)
num_timesteps = 10
num_particles = Pytree.const(100)
num_clusters = Pytree.const(4)
relative_particle_poses_prior_params = (Pose.identity(), .5, 0.25)
@@ -32,7 +33,8 @@ def test_sparse_gps_simulate():

trace = ps.sparse_gps_model.simulate(key, (
(
num_timesteps, # const object
max_num_timesteps, # const object
num_timesteps,
num_particles, # const object
num_clusters, # const object
relative_particle_poses_prior_params,
@@ -43,11 +45,10 @@ def test_sparse_gps_simulate():
))


particle_dynamics_summary = trace.get_retval()[0]
final_state = trace.get_retval()[1]
particle_system_summary = trace.get_retval()[0]
latent_particle_model_args = trace.get_args()[0]

# import importlib
# importlib.reload(b3d.chisight.shared.particle_system)

ps.visualize_particle_system(latent_particle_model_args, particle_dynamics_summary, final_state)
ps.visualize_particle_system(latent_particle_model_args, particle_system_summary)