
Commit e6d7317 (1 parent: f00c63e)

cleaned up rpe cell -- online versus long-term predictor compartments

File tree: 1 file changed

ngclearn/components/neurons/graded/rewardErrorCell.py

Lines changed: 57 additions & 11 deletions
@@ -9,21 +9,34 @@ class RewardErrorCell(JaxComponent): ## Reward prediction error cell
 
     | --- Cell Input Compartments: ---
     | reward - current reward signal at time `t`
+    | accum_reward - current accumulated episodic reward signal
     | --- Cell Output Compartments: ---
     | mu - current moving average prediction of reward at time `t`
+    | rpe - current reward prediction error (RPE) signal
+    | accum_reward - current accumulated episodic reward signal (IF online predictor not used)
 
     Args:
         name: the string name of this cell
 
         n_units: number of cellular entities (neural population size)
 
         alpha: decay factor to apply to (exponential) moving average prediction
+
+        ema_window_len: exponential moving average window length -- for use only
+            in `evolve` step for updating episodic reward signals (default: 10)
+
+        use_online_predictor: use online prediction of reward signal (default: True)
+            -- if set to False, then reward prediction will only occur upon a call
+            to this cell's `evolve` function
     """
-    def __init__(self, name, n_units, alpha, batch_size=1, **kwargs):
+    def __init__(self, name, n_units, alpha, ema_window_len=10,
+                 use_online_predictor=True, batch_size=1, **kwargs):
         super().__init__(name, **kwargs)
 
         ## RPE meta-parameters
         self.alpha = alpha
+        self.ema_window_len = ema_window_len
+        self.use_online_predictor = use_online_predictor
 
         ## Layer Size Setup
         self.n_units = n_units
@@ -34,29 +47,55 @@ def __init__(self, name, n_units, alpha, batch_size=1, **kwargs):
         self.mu = Compartment(restVals) ## reward predictor state(s)
         self.reward = Compartment(restVals) ## target reward signal(s)
         self.rpe = Compartment(restVals) ## reward prediction error(s)
+        self.accum_reward = Compartment(restVals) ## accumulated reward signal(s)
+        self.n_ep_steps = Compartment(jnp.zeros((self.batch_size, 1))) ## number of episode steps taken
 
     @staticmethod
-    def _advance_state(dt, alpha, mu, rpe, reward):
+    def _advance_state(dt, use_online_predictor, alpha, mu, rpe, reward,
+                       n_ep_steps, accum_reward):
         ## compute/update RPE and predictor values
+        accum_reward = accum_reward + reward
         rpe = reward - mu
-        mu = mu * (1. - alpha) + reward * alpha
-        return mu, rpe
+        if use_online_predictor:
+            mu = mu * (1. - alpha) + reward * alpha
+        n_ep_steps = n_ep_steps + 1
+        return mu, rpe, n_ep_steps, accum_reward
 
     @resolver(_advance_state)
-    def advance_state(self, mu, rpe):
+    def advance_state(self, mu, rpe, n_ep_steps, accum_reward):
         self.mu.set(mu)
         self.rpe.set(rpe)
+        self.n_ep_steps.set(n_ep_steps)
+        self.accum_reward.set(accum_reward)
+
+    @staticmethod
+    def _evolve(dt, use_online_predictor, ema_window_len, n_ep_steps, mu,
+                accum_reward):
+        if not use_online_predictor: ## long-term predictor only runs when online prediction is off
+            ## total episodic reward signal
+            r = accum_reward/n_ep_steps
+            mu = (1. - 1./ema_window_len) * mu + (1./ema_window_len) * r
+        return mu
+
+    @resolver(_evolve)
+    def evolve(self, mu):
+        self.mu.set(mu)
 
     @staticmethod
     def _reset(batch_size, n_units):
-        mu = jnp.zeros((batch_size, n_units)) #None
-        rpe = jnp.zeros((batch_size, n_units)) #None
-        return mu, rpe
+        restVals = jnp.zeros((batch_size, n_units))
+        mu = restVals
+        rpe = restVals
+        accum_reward = restVals
+        n_ep_steps = jnp.zeros((batch_size, 1))
+        return mu, rpe, accum_reward, n_ep_steps
 
     @resolver(_reset)
-    def reset(self, mu, rpe):
+    def reset(self, mu, rpe, accum_reward, n_ep_steps):
         self.mu.set(mu)
         self.rpe.set(rpe)
+        self.accum_reward.set(accum_reward)
+        self.n_ep_steps.set(n_ep_steps)
 
     @classmethod
     def help(cls): ## component help function
@@ -69,16 +108,23 @@ def help(cls): ## component help function
                 {"reward": "External reward signals/values"},
             "outputs":
                 {"mu": "Current state of reward predictor",
-                 "rpe": "Current value of reward prediction error at time `t`"},
+                 "rpe": "Current value of reward prediction error at time `t`",
+                 "accum_reward": "Current accumulated episodic reward signal (generally "
+                                 "produced at the end of a control episode of `n_steps`)",
+                 "n_ep_steps": "Number of episodic steps taken/tracked thus far "
+                               "(since last `reset` call)",
+                 "use_online_predictor": "Should an online reward predictor be used/maintained?"},
         }
         hyperparams = {
             "n_units": "Number of neuronal cells to model in this layer",
             "alpha": "Moving average decay factor",
+            "ema_window_len": "Exponential moving average window length",
            "batch_size": "Batch size dimension of this component"
         }
         info = {cls.__name__: properties,
                 "compartments": compartment_props,
-                "dynamics": "rpe = reward - mu; mu = mu * (1 - alpha) + reward * alpha",
+                "dynamics": "rpe = reward - mu; mu = mu * (1 - alpha) + reward * alpha; "
+                            "accum_reward = accum_reward + reward",
                 "hyperparameters": hyperparams}
         return info
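For reference, the two predictor modes that this commit separates can be exercised outside of ngclearn's Compartment/resolver machinery. The sketch below is a minimal standalone illustration, assuming only `jax.numpy`: the `advance` and `evolve` helpers, the toy reward sequence, and the driver loop are illustrative names and are not part of the ngclearn API; only the update equations are taken from the diff above.

## Standalone sketch (assumed names; NOT the ngclearn API) of the cell's two predictor modes:
##   online mode   -> mu tracks reward with a per-step EMA, as in `_advance_state`
##   episodic mode -> mu is only updated at episode end from the mean episodic reward, as in `_evolve`
import jax.numpy as jnp

def advance(mu, accum_reward, n_ep_steps, reward, alpha, use_online_predictor=True):
    ## mirrors `_advance_state`: accumulate reward, emit RPE, optionally update the online predictor
    accum_reward = accum_reward + reward
    rpe = reward - mu
    if use_online_predictor:
        mu = mu * (1. - alpha) + reward * alpha
    n_ep_steps = n_ep_steps + 1
    return mu, rpe, n_ep_steps, accum_reward

def evolve(mu, accum_reward, n_ep_steps, ema_window_len, use_online_predictor=True):
    ## mirrors `_evolve`: fold the mean episodic reward into a long-term EMA of window
    ## length `ema_window_len`, but only when the online predictor is disabled (per the docstring)
    if not use_online_predictor:
        r = accum_reward / n_ep_steps  ## mean reward over the episode
        mu = (1. - 1. / ema_window_len) * mu + (1. / ema_window_len) * r
    return mu

if __name__ == "__main__":
    alpha, ema_window_len = 0.1, 10
    mu = jnp.zeros((1, 1))            ## batch_size=1, n_units=1
    accum_reward = jnp.zeros((1, 1))
    n_ep_steps = jnp.zeros((1, 1))
    for r_t in [0., 1., 1., 0., 1.]:  ## a toy 5-step episode
        reward = jnp.ones((1, 1)) * r_t
        mu, rpe, n_ep_steps, accum_reward = advance(
            mu, accum_reward, n_ep_steps, reward, alpha, use_online_predictor=False)
        print(f"r={r_t:.1f}  rpe={float(rpe[0, 0]):+.3f}  mu={float(mu[0, 0]):.3f}")
    ## episodic mode: mu only moves here, pulled 1/ema_window_len of the way toward the mean reward (0.6)
    mu = evolve(mu, accum_reward, n_ep_steps, ema_window_len, use_online_predictor=False)
    print(f"mu after evolve = {float(mu[0, 0]):.3f}")  ## 0.9 * 0.0 + 0.1 * 0.6 = 0.06

In ngclearn itself these same updates are driven through the component's `advance_state` and `evolve` resolvers shown in the diff rather than through plain function calls.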
