Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions examples/nestml_neuron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
r"""
Example of using a cell type defined in NESTML
"""

import sys
from copy import deepcopy
from pyNN.utility import init_logging, get_simulator, normalized_filename
import pyNN
import pyNN.nest
import pyNN.nest.nestml



sim, options = get_simulator(("--plot-figure", "plot a figure with the given filename"))
init_logging(None, debug=True)
sim.setup(timestep=0.1, min_delay=0.1, max_delay=2.0)
celltype_cls = pyNN.nest.nestml.nestml_celltype_from_model(nestml_file_name="izhikevich_neuron.nestml")

parameters = {
'a': .02,
'b': .2,
'c': -65.,
'd': 8.
}

print(celltype_cls.default_parameters)

cells = sim.Population(1, celltype_cls, parameters)
cells.initialize(V_m=-70.)

input = sim.Population(2, sim.SpikeSourcePoisson, {'rate': 500})

connector = sim.OneToOneConnector()
syn = sim.StaticSynapse(weight=5.0, delay=0.5)
conn = [sim.Projection(input[0:1], cells, connector, syn)]

cells.record(('V_m', 'U_m'))

sim.run(100.0)

cells.write_data(
normalized_filename("Results", "nestml_cell", "pkl",
options.simulator, sim.num_processes()),
annotations={'script_name': __file__})

data = cells.get_data().segments[0]

sim.end()

if options.plot_figure:
from pyNN.utility.plotting import Figure, Panel

Figure(
Panel(data.filter(name='V_m')[0], ylabel="V_m", xlabel="Time (ms)", xticks=True),
Panel(data.filter(name='U_m')[0], ylabel="U_m", xlabel="Time (ms)", xticks=True),
title=__file__
).save(options.plot_figure)

print(data.spiketrains)
80 changes: 80 additions & 0 deletions izhikevich_neuron.nestml
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# izhikevich - Izhikevich neuron model
# ####################################
#
# Description
# +++++++++++
#
# Implementation of the simple spiking neuron model introduced by Izhikevich [1]_. The dynamics are given by:
#
# .. math::
#
# dV_{m}/dt &= 0.04 V_{m}^2 + 5 V_{m} + 140 - U_{m} + I\\
# dU_{m}/dt &= a (b V_{m} - U_{m})
#
#
# .. math::
#
# &\text{if}\;\; V_{m} \geq V_{th}:\\
# &\;\;\;\; V_{m} \text{ is set to } c\\
# &\;\;\;\; U_{m} \text{ is incremented by } d\\
# & \, \\
# &V_{m} \text{ jumps on each spike arrival by the weight of the spike}
#
# Incoming spikes cause an instantaneous jump in the membrane potential proportional to the strength of the synapse.
#
# As published in [1]_, the numerics differs from the standard forward Euler technique in two ways:
#
# 1) the new value of :math:`U_{m}` is calculated based on the new value of :math:`V_{m}`, rather than the previous value
# 2) the variable :math:`V_{m}` is updated using a time step half the size of that used to update variable :math:`U_{m}`.
#
# This model will instead be simulated using the numerical solver that is recommended by ODE-toolbox during code generation.
#
#
# References
# ++++++++++
#
# .. [1] Izhikevich, Simple Model of Spiking Neurons, IEEE Transactions on Neural Networks (2003) 14:1569-1572
#
#
model izhikevich_neuron:
state:
V_m mV = V_m_init # Membrane potential
U_m real = b * V_m_init # Membrane potential recovery variable

equations:
V_m' = ( 0.04 * V_m * V_m / mV + 5.0 * V_m + ( 140 - U_m ) * mV + ( (I_e + I_stim) * GOhm ) ) / ms
U_m' = a*(b*V_m-U_m * mV) / (mV*ms)

parameters:
a real = 0.02 # describes time scale of recovery variable
b real = 0.2 # sensitivity of recovery variable
c mV = -65 mV # after-spike reset value of V_m
d real = 8.0 # after-spike reset value of U_m
V_m_init mV = -65 mV # initial membrane potential
V_min mV = -inf * mV # Absolute lower value for the membrane potential.
V_th mV = 30 mV # Threshold potential

# constant external input current
I_e pA = 0 pA

input:
spikes <- spike
I_stim pA <- continuous

output:
spike

update:
integrate_odes()

# Add synaptic current
V_m += spikes * mV * s

# lower bound of membrane potential
V_m = max(V_min, V_m)

onCondition(V_m >= V_th):
# threshold crossing
V_m = c
U_m += d
emit_spike()
76 changes: 76 additions & 0 deletions pyNN/nest/nestml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
"""
Support cell types defined in NESTML (https://nestml.readthedocs.org/).

Requires NESTML to be installed.

:copyright: Copyright 2006-2024 by the PyNN team, see AUTHORS.
:license: CeCILL, see LICENSE for details.
"""

import logging
import pynestml
import pynestml.frontend
import pynestml.frontend.pynestml_frontend
from pynestml.utils.model_parser import ModelParser
from pynestml.codegeneration.python_standalone_target_tools import PythonStandaloneTargetTools
from pyNN.nest.cells import NativeCellType


logger = logging.getLogger("PyNN")


class NESTMLCellType(NativeCellType):

def __init__(self, parameters):
NativeCellType.__init__(self, parameters)


def nestml_celltype_from_model(nestml_file_name: str):
"""
Return a new NativeCellType subclass from a NESTML model.
"""

dct = {'nestml_file_name': nestml_file_name}
return _nest_build_nestml_celltype((NESTMLCellType,), dct)


class _nest_build_nestml_celltype(type):
"""
Metaclass for building NESTMLCellType subclasses
"""
def __new__(cls, bases, dct):
import nest
import pynestml

nestml_file_name = dct['nestml_file_name']

pynestml.frontend.pynestml_frontend.generate_target(input_path=nestml_file_name,
target_platform="NEST",
suffix="_nestml",
logging_level="WARNING")

ast_compilation_unit = ModelParser.parse_file(nestml_file_name)
if ast_compilation_unit is None or len(ast_compilation_unit.get_model_list()) == 0:
raise("Error(s) occurred during code generation; please check error messages")

model: ASTModel = ast_compilation_unit.get_model_list()[0]
model_name = model.get_name()

dct["default_parameters"], dct["default_initial_values"] = PythonStandaloneTargetTools.get_neuron_parameters_and_state(nestml_file_name)
dct["synapse_types"] = [port.name for port in model.get_spike_input_ports()]
dct["standard_receptor_type"] = ()
dct["injectable"] = bool(model.get_continuous_input_ports()) # assume that in case there is a continuous-time input port, it corresponds with a current injection port
dct["conductance_based"] = False # this is only used for checking sign of the incoming weights -- assume always false to skip the check
dct["model_name"] = model_name
dct["nest_model"] = model_name

# Recording from bindings:
dct["recordable"] = dct["default_initial_values"].keys()
# XXX TODO: add recordable inlines

dct["weight_variables"] = [] # XXX: none for neuron models, no?

logger.debug("Creating class '%s' with bases %s and dictionary %s" % (model_name, bases, dct))

return type.__new__(cls, model_name, bases, dct)
Loading