
Commit 8d74157

Patched synapses added (#68)
* Patched synapses added
* Update __init__.py
* Update patch_utils.py: patch_with_stride & patch_with_overlap functions + Create_Patches class added
* Update patchedSynapse.py
* Update hebbianPatchedSynapse.py
* Update synapse_plot.py: order added
1 parent 42167cb commit 8d74157

6 files changed, +639 -2 lines changed

ngclearn/components/synapses/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -20,3 +20,7 @@
 from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
 ## modulated synaptic components
 from .modulated.MSTDPETSynapse import MSTDPETSynapse
+## patched synaptic components
+from .patched.patchedSynapse import PatchedSynapse
+from .patched.staticPatchedSynapse import StaticPatchedSynapse
+from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse
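With these exports in place, the new patched components can be imported directly from ngclearn.components.synapses. The following is a minimal, purely illustrative sketch (not part of the commit) that mirrors the demo at the bottom of hebbianPatchedSynapse.py; the concrete shape and learning-rate values are made up:

from ngcsimlib.context import Context
from ngclearn.components.synapses import HebbianPatchedSynapse

with Context("demo") as ctx:
    ## a 9-input x 30-output cable carved into 3 sub-models (patches)
    W = HebbianPatchedSynapse("W", shape=(9, 30), n_sub_models=3, eta=0.01)
    print(W)                      ## per-compartment tensor statistics via __repr__
    print(W.weights.value.shape)  ## overall (possibly stride-padded) weight shape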
ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 316 additions & 0 deletions
@@ -0,0 +1,316 @@
import matplotlib.pyplot as plt
from jax import random, numpy as jnp, jit
from functools import partial
from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn
from ngclearn import resolver, Component, Compartment
from ngclearn.components.synapses import PatchedSynapse
from ngclearn.utils import tensorstats

@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1., w_decay=0.,
                 pre_wght=1., post_wght=1.):
    """
    Compute a tensor of adjustments to be applied to a synaptic value matrix.

    Args:
        pre: pre-synaptic statistic to drive Hebbian update

        post: post-synaptic statistic to drive Hebbian update

        W: synaptic weight values (at time t)

        w_mask: optional mask applied multiplicatively to the computed update
            (if None, no masking is applied)

        w_bound: maximum value to enforce over newly computed efficacies

        is_nonnegative: (Unused)

        signVal: multiplicative factor to modulate final update by (good for
            flipping the signs of a computed synaptic change matrix)

        w_decay: synaptic decay factor to apply to this update

        pre_wght: pre-synaptic weighting term (Default: 1.)

        post_wght: post-synaptic weighting term (Default: 1.)

    Returns:
        an update/adjustment matrix, an update/adjustment vector (for biases)
    """
    _pre = pre * pre_wght
    _post = post * post_wght
    dW = jnp.matmul(_pre.T, _post)
    db = jnp.sum(_post, axis=0, keepdims=True)
    if w_bound > 0.:
        dW = dW * (w_bound - jnp.abs(W))
    if w_decay > 0.:
        dW = dW - W * w_decay

    if w_mask is not None:
        dW = dW * w_mask

    return dW * signVal, db * signVal

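## Shape note (added commentary; the concrete sizes are illustrative, not from
## the commit): for a batch of B samples, `pre` is (B, n_in) and `post` is
## (B, n_out), so jnp.matmul(_pre.T, _post) produces a dW of shape (n_in, n_out)
## matching W, while db sums over the batch axis to shape (1, n_out). When
## w_bound > 0, dW is rescaled by (w_bound - |W|), shrinking updates as weights
## approach the bound; when a w_mask is supplied, it zeroes out updates to
## connections outside the allowed (patch) structure.
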
@partial(jit, static_argnums=[1, 2, 3])
def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True):
    """
    Enforces constraints that the (synaptic) efficacies/values within matrix
    `W` must adhere to.

    Args:
        W: synaptic weight values (at time t)

        w_mask: optional mask applied multiplicatively to the constrained
            weights (if None, no masking is applied)

        w_bound: maximum value to enforce over newly computed efficacies

        is_nonnegative: ensure updated value matrix is strictly non-negative

    Returns:
        the newly evolved synaptic weight value matrix
    """
    _W = W
    if w_bound > 0.:
        if is_nonnegative:
            _W = jnp.clip(_W, 0., w_bound)
        else:
            _W = jnp.clip(_W, -w_bound, w_bound)

    if w_mask is not None:
        _W = _W * w_mask

    return _W

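## Usage note (added commentary): within this component, _enforce_constraints is
## invoked inside _evolve after the optimizer step, so weights are clipped to
## [0, w_bound] (or [-w_bound, w_bound]) and re-masked on every update.
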
class HebbianPatchedSynapse(PatchedSynapse):
    """
    A synaptic cable that adjusts its efficacies via a two-factor Hebbian
    adjustment rule.

    | --- Synapse Compartments: ---
    | inputs - input (takes in external signals)
    | outputs - output signals (transformation induced by synapses)
    | weights - current value matrix of synaptic efficacies
    | biases - current value vector of synaptic bias values
    | key - JAX PRNG key
    | --- Synaptic Plasticity Compartments: ---
    | pre - pre-synaptic signal to drive first term of Hebbian update (takes in external signals)
    | post - post-synaptic signal to drive second term of Hebbian update (takes in external signals)
    | dWeights - current delta matrix containing changes to be applied to synaptic efficacies
    | dBiases - current delta vector containing changes to be applied to bias values
    | opt_params - locally-embedded optimizer statistics (e.g., Adam 1st/2nd moments if adam is used)

    Args:
        name: the string name of this cell

        shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
            with number of inputs by number of outputs)

        n_sub_models: the number of sub-models (patches) this cable is split across

        stride_shape: stride shape of the overlapping synaptic weight value matrix
            (Default: (0, 0))

        eta: global learning rate

        weight_init: a kernel to drive initialization of this synaptic cable's values;
            typically a tuple with 1st element as a string calling the name of
            initialization to use

        bias_init: a kernel to drive initialization of biases for this synaptic cable
            (Default: None, which turns off/disables biases)

        w_mask: optional mask restricting which synapses exist / receive updates
            (Default: None)

        w_bound: maximum weight to softly bound this cable's value matrix to; if
            set to 0, then no synaptic value bounding will be applied

        is_nonnegative: enforce that synaptic efficacies are always non-negative
            after each synaptic update (if False, no constraint will be applied)

        w_decay: degree to which (L2) synaptic weight decay is applied to the
            computed Hebbian adjustment (Default: 0); note that decay is not
            applied to any configured biases

        sign_value: multiplicative factor to apply to final synaptic update before
            it is applied to synapses; this is useful if gradient descent style
            optimization is required (as Hebbian rules typically yield
            adjustments for ascent)

        optim_type: optimization scheme to physically alter synaptic values
            once an update is computed (Default: "sgd"); supported schemes
            include "sgd" and "adam"

            :Note: technically, if "sgd" or "adam" is used with `sign_value = 1`,
                the ascent form of each rule is employed; setting
                `sign_value = -1` (or using a negative learning rate) yields the
                descent form of the chosen `optim_type`

        pre_wght: pre-synaptic weighting factor (Default: 1.)

        post_wght: post-synaptic weighting factor (Default: 1.)

        resist_scale: a fixed scaling factor to apply to synaptic transform
            (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b

        p_conn: probability of a connection existing (default: 1.); setting
            this to < 1. will result in a sparser synaptic structure

        batch_size: batch size dimension of this component (Default: 1)
    """

    def __init__(self, name, shape, n_sub_models, stride_shape=(0, 0), eta=0., weight_init=None, bias_init=None,
                 w_mask=None, w_bound=1., is_nonnegative=False, w_decay=0., sign_value=1.,
                 optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
                 resist_scale=1., batch_size=1, **kwargs):
        super().__init__(name, shape, n_sub_models, stride_shape, w_mask, weight_init, bias_init, resist_scale,
                         p_conn, batch_size=batch_size, **kwargs)

        self.n_sub_models = n_sub_models
        self.sub_stride = stride_shape

        self.shape = (shape[0] + (2 * stride_shape[0]),
                      shape[1] + (2 * stride_shape[1]))
        self.sub_shape = (shape[0] // n_sub_models + (2 * stride_shape[0]),
                          shape[1] // n_sub_models + (2 * stride_shape[1]))

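        ## Shape bookkeeping (added commentary with illustrative numbers not in
        ## the original commit): for shape=(9, 30), n_sub_models=3, and
        ## stride_shape=(0, 0), self.shape remains (9, 30) while self.sub_shape
        ## becomes (3, 10), i.e., each of the 3 sub-models owns a 3x10 block of
        ## the full weight matrix; a nonzero stride_shape pads both the overall
        ## and per-patch dimensions by twice the stride, presumably so that
        ## neighboring patches can overlap.
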
        ## synaptic plasticity properties and characteristics
        self.Rscale = resist_scale
        self.w_bound = w_bound
        self.w_decay = w_decay  ## synaptic decay
        self.pre_wght = pre_wght
        self.post_wght = post_wght
        self.eta = eta
        self.is_nonnegative = is_nonnegative
        self.sign_value = sign_value

        ## optimization / adjustment properties (given learning dynamics above)
        self.opt = get_opt_step_fn(optim_type, eta=self.eta)

        # compartments (state of the cell, parameters, will be updated through stateless calls)
        self.preVals = jnp.zeros((self.batch_size, self.shape[0]))
        self.postVals = jnp.zeros((self.batch_size, self.shape[1]))
        self.pre = Compartment(self.preVals)
        self.post = Compartment(self.postVals)
        self.w_mask = w_mask
        self.dWeights = Compartment(jnp.zeros(self.shape))
        self.dBiases = Compartment(jnp.zeros(self.shape[1]))

        #key, subkey = random.split(self.key.value)
        self.opt_params = Compartment(get_opt_init_fn(optim_type)(
            [self.weights.value, self.biases.value]
            if bias_init else [self.weights.value]))

    @staticmethod
    def _compute_update(w_mask, w_bound, is_nonnegative, sign_value, w_decay, pre_wght,
                        post_wght, pre, post, weights):
        ## calculate synaptic update values
        dW, db = _calc_update(
            pre, post, weights, w_mask, w_bound, is_nonnegative=is_nonnegative,
            signVal=sign_value, w_decay=w_decay, pre_wght=pre_wght,
            post_wght=post_wght)

        return dW * jnp.where(jnp.abs(weights) != 0, 1, 0), db

    @staticmethod
    def _evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, w_decay, pre_wght,
                post_wght, bias_init, pre, post, weights, biases, opt_params):
        ## calculate synaptic update values
        dWeights, dBiases = HebbianPatchedSynapse._compute_update(
            w_mask, w_bound, is_nonnegative, sign_value, w_decay,
            pre_wght, post_wght, pre, post, weights
        )
        ## conduct a step of optimization - get newly evolved synaptic weight value matrix
        if bias_init is not None:
            opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
        else:
            # ignore db since no biases configured
            opt_params, [weights] = opt(opt_params, [weights], [dWeights])
        ## ensure synaptic efficacies adhere to constraints
        weights = _enforce_constraints(weights, w_mask, w_bound, is_nonnegative=is_nonnegative)
        return opt_params, weights, biases, dWeights, dBiases

    @resolver(_evolve)
    def evolve(self, opt_params, weights, biases, dWeights, dBiases):
        self.opt_params.set(opt_params)
        self.weights.set(weights)
        self.biases.set(biases)
        self.dWeights.set(dWeights)
        self.dBiases.set(dBiases)

    @staticmethod
    def _reset(batch_size, shape):
        preVals = jnp.zeros((batch_size, shape[0]))
        postVals = jnp.zeros((batch_size, shape[1]))
        return (
            preVals,             # inputs
            postVals,            # outputs
            preVals,             # pre
            postVals,            # post
            jnp.zeros(shape),    # dW
            jnp.zeros(shape[1]), # db
        )

    @classmethod
    def help(cls):  ## component help function
        properties = {
            "synapse_type": "HebbianPatchedSynapse - performs an adaptable synaptic "
                            "transformation of inputs to produce output signals; "
                            "synapses are adjusted via two-term/factor Hebbian adjustment"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values",
                 "pre": "Pre-synaptic statistic for Hebb rule (z_j)",
                 "post": "Post-synaptic statistic for Hebb rule (z_i)"},
            "states":
                {"weights": "Synapse efficacy/strength parameter values",
                 "biases": "Base-rate/bias parameter values",
                 "key": "JAX PRNG key"},
            "analytics":
                {"dWeights": "Synaptic weight value adjustment matrix produced at time t",
                 "dBiases": "Synaptic bias/base-rate value adjustment vector produced at time t"},
            "outputs":
                {"outputs": "Output of synaptic transformation"},
        }
        hyperparams = {
            "shape": "Overall shape of synaptic weight value matrix; number inputs x number outputs",
            "n_sub_models": "The number of sub-models in each layer",
            "stride_shape": "Stride shape of overlapping synaptic weight value matrix",
            "batch_size": "Batch size dimension of this component",
            "weight_init": "Initialization conditions for synaptic weight (W) values",
            "bias_init": "Initialization conditions for bias/base-rate (b) values",
            "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
            "is_nonnegative": "Should synapses be constrained to be non-negative post-updates?",
            "sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0",
            "eta": "Global (fixed) learning rate",
            "pre_wght": "Pre-synaptic weighting coefficient (q_pre)",
            "post_wght": "Post-synaptic weighting coefficient (q_post)",
            "w_bound": "Soft synaptic bound applied to synapses post-update",
            "w_decay": "Synaptic decay term",
            "optim_type": "Choice of optimizer to adjust synaptic weights"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs = [(W * Rscale) * inputs] + b ; "
                            "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - W_{ij} * w_decay",
                "hyperparameters": hyperparams}
        return info

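    ## Dynamics note (added commentary): per the help() "dynamics" string above,
    ## one evolve step implements
    ##     outputs = [(W * Rscale) * inputs] + b
    ##     dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - W_{ij} * w_decay
    ## where z_j/z_i are the pre/post compartment values, q_pre/q_post are
    ## pre_wght/post_wght, eta enters through the optimizer step function, and
    ## sign_value can flip the update for descent-style optimization.
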
    @resolver(_reset)
    def reset(self, inputs, outputs, pre, post, dWeights, dBiases):
        self.inputs.set(inputs)
        self.outputs.set(outputs)
        self.pre.set(pre)
        self.post.set(post)
        self.dWeights.set(dWeights)
        self.dBiases.set(dBiases)

    def __repr__(self):
        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
        maxlen = max(len(c) for c in comps) + 5
        lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
        for c in comps:
            stats = tensorstats(getattr(self, c).value)
            if stats is not None:
                line = [f"{k}: {v}" for k, v in stats.items()]
                line = ", ".join(line)
            else:
                line = "None"
            lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
        return lines

if __name__ == '__main__':
    from ngcsimlib.context import Context
    with Context("Bar") as bar:
        Wab = HebbianPatchedSynapse("Wab", (9, 30), 3)
    print(Wab)
    plt.imshow(Wab.weights.value, cmap='gray')
    plt.show()
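Beyond the quick visualization above, the module-level helpers can be exercised directly to see the shape of one Hebbian adjustment. The sketch below is illustrative only and is not part of the commit; the import path is inferred from this commit's new `patched` sub-package, and the batch size, activity values, and hyperparameters are made up:

import jax.numpy as jnp
from jax import random
from ngclearn.components.synapses.patched.hebbianPatchedSynapse import (
    _calc_update, _enforce_constraints)

key = random.PRNGKey(0)
k1, k2 = random.split(key)
pre = random.normal(k1, (8, 9))    ## batch of 8 pre-synaptic activity vectors
post = random.normal(k2, (8, 30))  ## matching post-synaptic activity vectors
W = jnp.ones((9, 30)) * 0.1        ## stand-in weight matrix

## two-factor Hebbian adjustment (no mask, soft bound of 1)
dW, db = _calc_update(pre, post, W, None, 1., is_nonnegative=False,
                      signVal=1., w_decay=0., pre_wght=1., post_wght=1.)
W_new = _enforce_constraints(W + 0.01 * dW, None, 1., is_nonnegative=False)
print(dW.shape, db.shape, W_new.shape)  ## (9, 30) (1, 30) (9, 30)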
