Commit
feat NGC module regression (#86)
* feat NGC module regression

* Update __init__.py

* Update __init__.py

* Update elastic_net.py

* Update lasso.py

* Update ridge.py

* Update elastic_net.py

* Update ridge.py

* Update lasso.py
Faezehabibi authored Dec 9, 2024
1 parent eeb057a commit cf53968
Showing 5 changed files with 480 additions and 0 deletions.
8 changes: 8 additions & 0 deletions ngclearn/modules/__init__.py
@@ -0,0 +1,8 @@
from .regression.elastic_net import Iterative_ElasticNet
from .regression.lasso import Iterative_Lasso
from .regression.ridge import Iterative_Ridge





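With these re-exports in place (and the package-relative imports fixed above), the three solvers can be imported directly from ngclearn.modules. A minimal sketch; the dimensions and solver name below are illustrative assumptions, not part of this commit:

    from jax import random
    from ngclearn.modules import Iterative_ElasticNet, Iterative_Lasso, Iterative_Ridge

    key = random.PRNGKey(1234)
    ## dict_dim predictors mapped onto sys_dim targets, batch_size samples at a time
    solver = Iterative_Lasso(key, "lasso_solver", sys_dim=1, dict_dim=16, batch_size=32)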
9 changes: 9 additions & 0 deletions ngclearn/modules/regression/__init__.py
@@ -0,0 +1,9 @@
from .elastic_net import Iterative_ElasticNet
from .lasso import Iterative_Lasso
from .ridge import Iterative_Ridge






153 changes: 153 additions & 0 deletions ngclearn/modules/regression/elastic_net.py
@@ -0,0 +1,153 @@
from jax import random, jit
import numpy as np
from ngclearn.utils import weight_distribution as dist
from ngclearn import Context, numpy as jnp
from ngclearn.components import (RateCell, HebbianSynapse,
                                 GaussianErrorCell, StaticSynapse)
from ngclearn.utils.model_utils import scanner


class Iterative_ElasticNet():
    """
    A neural circuit implementation of the iterative Elastic Net (L1 and L2)
    algorithm using a Hebbian learning update rule.

    The circuit implements sparse regression through Hebbian synapses with
    Elastic Net regularization. The update that characterizes this model adds
    a regularization term lmbda * dW_reg to dW (the gradient of the
    loss/energy function) when adjusting W:

    | dW_reg = (jnp.sign(W) * l1_ratio) + (W * (1 - l1_ratio) / 2)
    | dW/dt = dW + lmbda * dW_reg

    | --- Circuit Components: ---
    | W - HebbianSynapse for learning regularized dictionary weights
    | err - GaussianErrorCell for computing prediction errors
    | --- Component Compartments: ---
    | W.inputs - input features (takes in external signals)
    | W.pre - pre-synaptic activity for Hebbian learning
    | W.post - post-synaptic error signals
    | W.weights - learned dictionary coefficients
    | err.mu - predicted outputs
    | err.target - target signals (target vector)
    | err.dmu - error gradients
    | err.L - loss/energy values

    Args:
        key: JAX PRNG key for random number generation
        name: string name for this solver
        sys_dim: dimensionality of the system/target space
        dict_dim: dimensionality of the dictionary/feature space (the number
            of predictors)
        batch_size: number of samples to process in parallel
        weight_fill: initial constant value to fill the weight matrix with
            (Default: 0.05)
        lr: learning rate for synaptic weight updates (Default: 0.01)
        lmbda: Elastic Net regularization lambda parameter (Default: 0.0001)
        l1_ratio: mixing ratio between the L1 and L2 penalties; 1.0 yields
            pure Lasso and 0.0 pure Ridge behavior (Default: 0.5)
        optim_type: optimization type for updating weights; supported values
            are "sgd" and "adam" (Default: "adam")
        threshold: minimum absolute coefficient value - values below this are
            set to zero during thresholding (Default: 0.05)
        epochs: number of training epochs (Default: 100)
    """
    def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05,
                 lr=0.01, lmbda=0.0001, l1_ratio=0.5, optim_type="adam",
                 threshold=0.05, epochs=100):
        key, *subkeys = random.split(key, 10)

        ## synaptic plasticity properties and characteristics
        self.T = 100
        self.dt = 1
        self.epochs = epochs
        self.weight_fill = weight_fill
        self.threshold = threshold
        self.name = name
        feature_dim = dict_dim

        with Context(self.name) as self.circuit:
            self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=lr,
                                    sign_value=-1, weight_init=dist.constant(weight_fill),
                                    prior=('elastic_net', (lmbda, l1_ratio)),
                                    optim_type=optim_type, key=subkeys[0])
            self.err = GaussianErrorCell("err", n_units=sys_dim)
            ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.W.batch_size = batch_size
            self.err.batch_size = batch_size
            ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.err.mu << self.W.outputs
            self.W.post << self.err.dmu
            ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            advance_cmd, advance_args = self.circuit.compile_by_key(
                self.W,    ## execute prediction synapses
                self.err,  ## finally, execute error neurons
                compile_key="advance_state")
            evolve_cmd, evolve_args = self.circuit.compile_by_key(
                self.W, compile_key="evolve")
            reset_cmd, reset_args = self.circuit.compile_by_key(
                self.err, self.W, compile_key="reset")
            ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.dynamic()

    def dynamic(self):  ## create dynamic commands for self.circuit
        W, err = self.circuit.get_components("W", "err")
        self.W = W
        self.err = err

        @Context.dynamicCommand
        def batch_set(batch_size):
            self.W.batch_size = batch_size
            self.err.batch_size = batch_size

        @Context.dynamicCommand
        def clamps(y_scaled, X):
            self.W.inputs.set(X)
            self.W.pre.set(X)
            self.err.target.set(y_scaled)

        self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve")
        self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance")
        self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset")

        @scanner
        def _process(compartment_values, args):
            _t, _dt = args
            compartment_values = self.circuit.advance_state(compartment_values,
                                                            t=_t, dt=_dt)
            return compartment_values, compartment_values[self.W.weights.path]


    def thresholding(self, scale=1.):
        coef_old = self.coef_
        new_coeff = jnp.where(jnp.abs(coef_old) >= self.threshold, coef_old, 0.)

        self.coef_ = new_coeff * scale
        self.W.weights.set(new_coeff)

        return self.coef_, coef_old


    def fit(self, y, X):
        self.circuit.reset()
        self.circuit.clamps(y_scaled=y, X=X)

        for i in range(self.epochs):
            self.circuit._process(jnp.array([[self.dt * t, self.dt]
                                             for t in range(self.T)]))
            self.circuit.evolve(t=self.T, dt=self.dt)

        self.coef_ = np.array(self.W.weights.value)

        return self.coef_, self.err.mu.value, self.err.L.value





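For context, a hedged end-to-end sketch of Iterative_ElasticNet on synthetic sparse-regression data; the data, shapes, and hyperparameter values below are illustrative assumptions, not part of this commit:

    import numpy as np
    from jax import random
    from ngclearn.modules import Iterative_ElasticNet

    key = random.PRNGKey(0)
    n_samples, dict_dim, sys_dim = 32, 8, 1
    X = np.random.randn(n_samples, dict_dim).astype(np.float32)
    w_true = np.zeros((dict_dim, sys_dim), dtype=np.float32)
    w_true[:3] = 1.0                       ## only the first 3 predictors are active
    y = X @ w_true                         ## (n_samples, sys_dim) targets

    model = Iterative_ElasticNet(key, "enet", sys_dim=sys_dim, dict_dim=dict_dim,
                                 batch_size=n_samples, lmbda=0.0001, l1_ratio=0.5)
    coef, mu, loss = model.fit(y, X)       ## note the (y, X) argument order
    coef, coef_raw = model.thresholding()  ## zero out coefficients with |w| < threshold

Note that fit clamps X to both W.inputs and W.pre and the targets to err.target, then alternates a T-step advance-state scan with one evolve step per epoch.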
157 changes: 157 additions & 0 deletions ngclearn/modules/regression/lasso.py
@@ -0,0 +1,157 @@
from jax import random, jit
import numpy as np
from ngclearn.utils import weight_distribution as dist
from ngclearn import Context, numpy as jnp
from ngclearn.components import (RateCell, HebbianSynapse,
                                 GaussianErrorCell, StaticSynapse)
from ngclearn.utils.model_utils import scanner


class Iterative_Lasso():
    """
    A neural circuit implementation of the iterative Lasso (L1) algorithm
    using a Hebbian learning update rule.

    The circuit implements sparse coding through Hebbian synapses with L1
    regularization. The update that characterizes this model adds
    lmbda * sign(W) to dW (the gradient of the loss/energy function):

    | dW/dt = dW + lmbda * sign(W)

    | --- Circuit Components: ---
    | W - HebbianSynapse for learning sparse dictionary weights
    | err - GaussianErrorCell for computing prediction errors
    | --- Component Compartments: ---
    | W.inputs - input features (takes in external signals)
    | W.pre - pre-synaptic activity for Hebbian learning
    | W.post - post-synaptic error signals
    | W.weights - learned dictionary coefficients
    | err.mu - predicted outputs
    | err.target - target signals (target vector)
    | err.dmu - error gradients
    | err.L - loss/energy values

    Args:
        key: JAX PRNG key for random number generation
        name: string name for this solver
        sys_dim: dimensionality of the system/target space
        dict_dim: dimensionality of the dictionary/feature space (the number
            of predictors)
        batch_size: number of samples to process in parallel
        weight_fill: initial constant value to fill the weight matrix with
            (Default: 0.05)
        lr: learning rate for synaptic weight updates (Default: 0.01)
        lasso_lmbda: L1 regularization lambda parameter (Default: 0.0001)
        optim_type: optimization type for updating weights; supported values
            are "sgd" and "adam" (Default: "adam")
        threshold: minimum absolute coefficient value - values below this are
            set to zero during thresholding (Default: 0.001)
        epochs: number of training epochs (Default: 100)
    """

    def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05,
                 lr=0.01, lasso_lmbda=0.0001, optim_type="adam", threshold=0.001,
                 epochs=100):
        key, *subkeys = random.split(key, 10)

        self.T = 100
        self.dt = 1
        self.epochs = epochs
        self.weight_fill = weight_fill
        self.threshold = threshold
        self.name = name
        feature_dim = dict_dim

        with Context(self.name) as self.circuit:
            self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=lr,
                                    sign_value=-1, weight_init=dist.constant(weight_fill),
                                    prior=('lasso', lasso_lmbda),
                                    optim_type=optim_type, key=subkeys[0])
            self.err = GaussianErrorCell("err", n_units=sys_dim)
            ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.W.batch_size = batch_size
            self.err.batch_size = batch_size
            ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.err.mu << self.W.outputs
            self.W.post << self.err.dmu
            ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            advance_cmd, advance_args = self.circuit.compile_by_key(
                self.W,    ## execute prediction synapses
                self.err,  ## finally, execute error neurons
                compile_key="advance_state")
            evolve_cmd, evolve_args = self.circuit.compile_by_key(
                self.W, compile_key="evolve")
            reset_cmd, reset_args = self.circuit.compile_by_key(
                self.err, self.W, compile_key="reset")
            ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.dynamic()

    def dynamic(self):  ## create dynamic commands for self.circuit
        W, err = self.circuit.get_components("W", "err")
        self.W = W
        self.err = err

        @Context.dynamicCommand
        def batch_set(batch_size):
            self.W.batch_size = batch_size
            self.err.batch_size = batch_size

        @Context.dynamicCommand
        def clamps(y_scaled, X):
            self.W.inputs.set(X)
            self.W.pre.set(X)
            self.err.target.set(y_scaled)

        self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve")
        self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance")
        self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset")

        @scanner
        def _process(compartment_values, args):
            _t, _dt = args
            compartment_values = self.circuit.advance_state(compartment_values,
                                                            t=_t, dt=_dt)
            return compartment_values, compartment_values[self.W.weights.path]


    def thresholding(self, scale=2):
        coef_old = self.coef_
        new_coeff = jnp.where(jnp.abs(coef_old) >= self.threshold, coef_old, 0.)

        self.coef_ = new_coeff * scale
        self.W.weights.set(new_coeff)

        return self.coef_, coef_old


    def fit(self, y, X):
        self.circuit.reset()
        self.circuit.clamps(y_scaled=y, X=X)

        for i in range(self.epochs):
            self.circuit._process(jnp.array([[self.dt * t, self.dt]
                                             for t in range(self.T)]))
            self.circuit.evolve(t=self.T, dt=self.dt)

        self.coef_ = np.array(self.W.weights.value)

        return self.coef_, self.err.mu.value, self.err.L.value





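For intuition, the L1 prior's effect on the synaptic update can be written out directly; a standalone JAX sketch of the rule stated in the docstring (lasso_adjusted_update is a hypothetical helper for illustration, not ngclearn API):

    from jax import numpy as jnp

    def lasso_adjusted_update(dW, W, lmbda=0.0001):
        ## implements the docstring rule dW/dt = dW + lmbda * sign(W);
        ## in the circuit, HebbianSynapse's sign_value=-1 sets the descent direction
        return dW + lmbda * jnp.sign(W)

The Elastic Net variant above swaps jnp.sign(W) for the mixed term (jnp.sign(W) * l1_ratio) + (W * (1 - l1_ratio) / 2).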


153 changes: 153 additions & 0 deletions ngclearn/modules/regression/ridge.py (diff not loaded)
