1
+ from functools import partial
2
+
1
3
import jax
2
4
import jax .numpy as jnp
3
5
import jax .random
17
19
from .utils import logmeanexp , update_field
18
20
19
21
22
+ @partial (jax .jit , static_argnames = ("do_advance_time" ))
20
23
def inference_step (
21
24
key ,
22
25
trace ,
@@ -28,7 +31,7 @@ def inference_step(
28
31
if do_advance_time :
29
32
key , subkey = split (key )
30
33
trace = advance_time (subkey , trace , observed_rgbd )
31
-
34
+
32
35
@jax .jit
33
36
def c2f_step (
34
37
key ,
@@ -41,13 +44,14 @@ def c2f_step(
41
44
42
45
# Propose the poses
43
46
pose_generation_keys = split (k1 , inference_hyperparams .n_poses )
44
- proposed_poses , log_q_poses = jax .vmap (propose_pose , in_axes = ( 0 , None , None , None ))(
45
- pose_generation_keys , trace , addr , pose_proposal_args
46
- )
47
+ proposed_poses , log_q_poses = jax .vmap (
48
+ propose_pose , in_axes = ( 0 , None , None , None )
49
+ )( pose_generation_keys , trace , addr , pose_proposal_args )
47
50
48
51
def update_and_get_scores (key , proposed_pose , trace , addr ):
49
- updated_trace , score = my_func (key , proposed_pose , trace , addr )
50
- return updated_trace , score
52
+ key , subkey = split (key )
53
+ updated_trace = update_field (subkey , trace , addr , proposed_pose )
54
+ return updated_trace , updated_trace .get_score ()
51
55
52
56
param_generation_keys = split (k2 , inference_hyperparams .n_poses )
53
57
_ , p_scores = jax .lax .map (
@@ -64,7 +68,10 @@ def update_and_get_scores(key, proposed_pose, trace, addr):
64
68
65
69
chosen_index = jax .random .categorical (k3 , weights )
66
70
resampled_trace , _ = update_and_get_scores (
67
- param_generation_keys [chosen_index ], proposed_poses [chosen_index ], trace , addr
71
+ param_generation_keys [chosen_index ],
72
+ proposed_poses [chosen_index ],
73
+ trace ,
74
+ addr ,
68
75
)
69
76
return (
70
77
resampled_trace ,
@@ -162,9 +169,3 @@ def propose_pose(key, advanced_trace, addr, args):
162
169
pose = Pose .sample_gaussian_vmf_pose (key , previous_pose , std , conc )
163
170
log_q = Pose .logpdf_gaussian_vmf_pose (pose , previous_pose , std , conc )
164
171
return pose , log_q
165
-
166
-
167
- def my_func (key , pose , trace , addr ):
168
- k1 , _ , _ , _ = split (key , 4 )
169
- updated_trace = update_field (k1 , trace , addr , pose )
170
- return updated_trace , updated_trace .get_score ()
0 commit comments