Skip to content

Commit 1e520af

Browse files
committed
make pr
1 parent 0c30c19 commit 1e520af

37 files changed

+15
-1306649
lines changed

debug.ipynb

-445
This file was deleted.

physion_utils.py

-49
This file was deleted.

src/b3d/chisight/gen3d/inference/inference.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import jax
24
import jax.numpy as jnp
35
import jax.random
@@ -17,6 +19,7 @@
1719
from .utils import logmeanexp, update_field
1820

1921

22+
@partial(jax.jit, static_argnames=("do_advance_time"))
2023
def inference_step(
2124
key,
2225
trace,
@@ -28,7 +31,7 @@ def inference_step(
2831
if do_advance_time:
2932
key, subkey = split(key)
3033
trace = advance_time(subkey, trace, observed_rgbd)
31-
34+
3235
@jax.jit
3336
def c2f_step(
3437
key,
@@ -41,13 +44,14 @@ def c2f_step(
4144

4245
# Propose the poses
4346
pose_generation_keys = split(k1, inference_hyperparams.n_poses)
44-
proposed_poses, log_q_poses = jax.vmap(propose_pose, in_axes=(0, None, None, None))(
45-
pose_generation_keys, trace, addr, pose_proposal_args
46-
)
47+
proposed_poses, log_q_poses = jax.vmap(
48+
propose_pose, in_axes=(0, None, None, None)
49+
)(pose_generation_keys, trace, addr, pose_proposal_args)
4750

4851
def update_and_get_scores(key, proposed_pose, trace, addr):
49-
updated_trace, score = my_func(key, proposed_pose, trace, addr)
50-
return updated_trace, score
52+
key, subkey = split(key)
53+
updated_trace = update_field(subkey, trace, addr, proposed_pose)
54+
return updated_trace, updated_trace.get_score()
5155

5256
param_generation_keys = split(k2, inference_hyperparams.n_poses)
5357
_, p_scores = jax.lax.map(
@@ -64,7 +68,10 @@ def update_and_get_scores(key, proposed_pose, trace, addr):
6468

6569
chosen_index = jax.random.categorical(k3, weights)
6670
resampled_trace, _ = update_and_get_scores(
67-
param_generation_keys[chosen_index], proposed_poses[chosen_index], trace, addr
71+
param_generation_keys[chosen_index],
72+
proposed_poses[chosen_index],
73+
trace,
74+
addr,
6875
)
6976
return (
7077
resampled_trace,
@@ -162,9 +169,3 @@ def propose_pose(key, advanced_trace, addr, args):
162169
pose = Pose.sample_gaussian_vmf_pose(key, previous_pose, std, conc)
163170
log_q = Pose.logpdf_gaussian_vmf_pose(pose, previous_pose, std, conc)
164171
return pose, log_q
165-
166-
167-
def my_func(key, pose, trace, addr):
168-
k1, _, _, _ = split(key, 4)
169-
updated_trace = update_field(k1, trace, addr, pose)
170-
return updated_trace, updated_trace.get_score()

src/b3d/chisight/gen3d/inference/utils.py

-2
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ def update_fields(key, trace, fieldnames, values):
2525
corresponding values in `values`. Returns a new trace.
2626
"""
2727
hyperparams, previous_state = trace.get_args()
28-
print("previous_state: ", previous_state)
29-
jax.debug.print("previous_state: {v}", v=previous_state)
3028
trace, _, _, _ = trace.update(
3129
key,
3230
U.g(

src/b3d/chisight/gen3d/settings.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import b3d.chisight.gen3d.transition_kernels as transition_kernels
22
from b3d.chisight.gen3d.hyperparams import InferenceHyperparams
33

4-
p_resample_color = 0.005
54
hyperparams = {
65
"pose_kernel": transition_kernels.GaussianVMFPoseDriftKernel(0.02, 1000.0),
76
"color_noise_variance": 1,

src/b3d/chisight/gen3d/transition_kernels.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def sample(self, key: PRNGKey, prev_pose):
3434
)
3535

3636
def logpdf(self, new_pose, prev_pose) -> ArrayLike:
37+
print(new_pose)
3738
return Pose.logpdf_gaussian_vmf_pose(
3839
new_pose, prev_pose, self.std, self.concentration
3940
)

0 commit comments

Comments
 (0)