@@ -9,21 +9,34 @@ class RewardErrorCell(JaxComponent): ## Reward prediction error cell

    | --- Cell Input Compartments: ---
    | reward - current reward signal at time `t`
+    | accum_reward - current accumulated episodic reward signal
    | --- Cell Output Compartments: ---
    | mu - current moving average prediction of reward at time `t`
+    | rpe - current reward prediction error (RPE) signal
+    | accum_reward - current accumulated episodic reward signal (IF online predictor not used)

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        alpha: decay factor to apply to (exponential) moving average prediction
+
+        ema_window_len: exponential moving average window length -- for use only
+            in `evolve` step for updating episodic reward signals (default: 10)
+
+        use_online_predictor: use online prediction of reward signal (default: True)
+            -- if set to False, then reward prediction will only occur upon a call
+            to this cell's `evolve` function
    """
-    def __init__(self, name, n_units, alpha, batch_size=1, **kwargs):
+    def __init__(self, name, n_units, alpha, ema_window_len=10,
+                 use_online_predictor=True, batch_size=1, **kwargs):
        super().__init__(name, **kwargs)

        ## RPE meta-parameters
        self.alpha = alpha
+        self.ema_window_len = ema_window_len
+        self.use_online_predictor = use_online_predictor

        ## Layer Size Setup
        self.n_units = n_units
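The constructor above only stores `alpha`, `ema_window_len`, and `use_online_predictor`; the per-step update they drive is plain exponential-moving-average arithmetic. The snippet below is a minimal, standalone sketch of that online update with illustrative values -- it is not the component's actual API, which keeps `mu`, `reward`, and `rpe` in Compartments.

```python
import jax.numpy as jnp

# Standalone sketch (illustrative only) of the online reward predictor that
# `alpha` parameterizes: an EMA of the reward, with the prediction error
# reported at every step.
alpha = 0.1
mu = jnp.zeros((1, 1))                  # predictor state: (batch_size, n_units)
for reward in (1.0, 0.0, 1.0, 1.0):
    r = jnp.full((1, 1), reward)
    rpe = r - mu                        # reward prediction error at this step
    mu = mu * (1. - alpha) + r * alpha  # EMA update used when use_online_predictor=True
    print(float(rpe[0, 0]), float(mu[0, 0]))
```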
@@ -34,29 +47,55 @@ def __init__(self, name, n_units, alpha, batch_size=1, **kwargs):
        self.mu = Compartment(restVals) ## reward predictor state(s)
        self.reward = Compartment(restVals) ## target reward signal(s)
        self.rpe = Compartment(restVals) ## reward prediction error(s)
+        self.accum_reward = Compartment(restVals) ## accumulated reward signal(s)
+        self.n_ep_steps = Compartment(jnp.zeros((self.batch_size, 1))) ## number of episode steps taken

    @staticmethod
-    def _advance_state(dt, alpha, mu, rpe, reward):
+    def _advance_state(dt, use_online_predictor, alpha, mu, rpe, reward,
+                       n_ep_steps, accum_reward):
        ## compute/update RPE and predictor values
+        accum_reward = accum_reward + reward
        rpe = reward - mu
-        mu = mu * (1. - alpha) + reward * alpha
-        return mu, rpe
+        if use_online_predictor:
+            mu = mu * (1. - alpha) + reward * alpha
+        n_ep_steps = n_ep_steps + 1
+        return mu, rpe, n_ep_steps, accum_reward

    @resolver(_advance_state)
-    def advance_state(self, mu, rpe):
+    def advance_state(self, mu, rpe, n_ep_steps, accum_reward):
        self.mu.set(mu)
        self.rpe.set(rpe)
+        self.n_ep_steps.set(n_ep_steps)
+        self.accum_reward.set(accum_reward)
+
+    @staticmethod
+    def _evolve(dt, use_online_predictor, ema_window_len, n_ep_steps, mu,
+                accum_reward):
+        if use_online_predictor:
+            ## total episodic reward signal
+            r = accum_reward / n_ep_steps
+            mu = (1. - 1. / ema_window_len) * mu + (1. / ema_window_len) * r
+        return mu
+
+    @resolver(_evolve)
+    def evolve(self, mu):
+        self.mu.set(mu)

    @staticmethod
    def _reset(batch_size, n_units):
-        mu = jnp.zeros((batch_size, n_units)) #None
-        rpe = jnp.zeros((batch_size, n_units)) #None
-        return mu, rpe
+        restVals = jnp.zeros((batch_size, n_units))
+        mu = restVals
+        rpe = restVals
+        accum_reward = restVals
+        n_ep_steps = jnp.zeros((batch_size, 1))
+        return mu, rpe, accum_reward, n_ep_steps

    @resolver(_reset)
-    def reset(self, mu, rpe):
+    def reset(self, mu, rpe, accum_reward, n_ep_steps):
        self.mu.set(mu)
        self.rpe.set(rpe)
+        self.accum_reward.set(accum_reward)
+        self.n_ep_steps.set(n_ep_steps)

    @classmethod
    def help(cls): ## component help function
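For the episodic path added in `_evolve`, the arithmetic is: average the accumulated reward over the episode's step count, then fold that average into `mu` through an EMA of window length `ema_window_len`. Below is a small standalone sketch with made-up values, not the component's compartments.

```python
import jax.numpy as jnp

# Sketch of the episodic update performed in _evolve (values are illustrative).
ema_window_len = 10
mu = jnp.asarray([[0.5]])              # current reward prediction
accum_reward = jnp.asarray([[12.0]])   # reward summed over the episode
n_ep_steps = jnp.asarray([[20.0]])     # steps taken this episode

r = accum_reward / n_ep_steps                                     # mean episodic reward = 0.6
mu = (1. - 1. / ema_window_len) * mu + (1. / ema_window_len) * r  # 0.9*0.5 + 0.1*0.6 = 0.51
print(float(mu[0, 0]))
```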
@@ -69,16 +108,23 @@ def help(cls): ## component help function
                {"reward": "External reward signals/values"},
            "outputs":
                {"mu": "Current state of reward predictor",
-                 "rpe": "Current value of reward prediction error at time `t`"},
+                 "rpe": "Current value of reward prediction error at time `t`",
+                 "accum_reward": "Current accumulated episodic reward signal (generally "
+                                 "produced at the end of a control episode of `n_steps`)",
+                 "n_ep_steps": "Number of episodic steps taken/tracked thus far "
+                               "(since last `reset` call)",
+                 "use_online_predictor": "Should an online reward predictor be used/maintained?"},
        }
        hyperparams = {
            "n_units": "Number of neuronal cells to model in this layer",
            "alpha": "Moving average decay factor",
+            "ema_window_len": "Exponential moving average window length",
            "batch_size": "Batch size dimension of this component"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
-                "dynamics": "rpe = reward - mu; mu = mu * (1 - alpha) + reward * alpha",
+                "dynamics": "rpe = reward - mu; mu = mu * (1 - alpha) + reward * alpha; "
+                            "accum_reward = accum_reward + reward",
                "hyperparameters": hyperparams}
        return info
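If this metadata is wanted at runtime, `help()` can be called directly on the class; the import path below is an assumption and may vary across ngc-learn versions.

```python
# Assumed import path -- adjust for the installed ngc-learn version.
from ngclearn.components import RewardErrorCell

info = RewardErrorCell.help()
print(info["dynamics"])          # update equations as a string
print(info["hyperparameters"])   # now includes "ema_window_len"
```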