diff --git a/.github/workflows/ode-toolbox-build.yml b/.github/workflows/ode-toolbox-build.yml index 892af236..a5223761 100644 --- a/.github/workflows/ode-toolbox-build.yml +++ b/.github/workflows/ode-toolbox-build.yml @@ -22,6 +22,7 @@ jobs: run: | python -m pip install --upgrade pip pytest pycodestyle codecov pytest-cov wheel python -m pip install -r requirements.txt + python -m pip install -r requirements-testing.txt export PYTHON_VERSION=`python -c "import sys; print('.'.join(map(str, [sys.version_info.major, sys.version_info.minor])))"` echo "Python version detected:" echo $PYTHON_VERSION @@ -33,7 +34,7 @@ jobs: - name: Static code style analysis run: | - python3 -m pycodestyle $GITHUB_WORKSPACE -v --ignore=E241,E501,E303,E714,E713,E714,E252 --exclude=$GITHUB_WORKSPACE/doc,$GITHUB_WORKSPACE/.eggs,$GITHUB_WORKSPACE/build,$GITHUB_WORKSPACE/.git,$GITHUB_WORKSPACE/odetoolbox.egg-info,$GITHUB_WORKSPACE/dist + python3 -m pycodestyle $GITHUB_WORKSPACE -v --ignore=E241,E501,E714,E713,E714,E252 --exclude=$GITHUB_WORKSPACE/doc,$GITHUB_WORKSPACE/.eggs,$GITHUB_WORKSPACE/build,$GITHUB_WORKSPACE/.git,$GITHUB_WORKSPACE/odetoolbox.egg-info,$GITHUB_WORKSPACE/dist build: @@ -75,6 +76,7 @@ jobs: python -m pip install numpy if [ "${{ matrix.with_gsl }}" == "1" ]; then python3 -m pip install -v https://github.com/pygsl/pygsl/archive/refs/tags/v2.4.1.tar.gz ; fi # this should be "pip install pygsl", but see https://github.com/pygsl/pygsl/issues/59 python -m pip install -r requirements.txt + python -m pip install -r requirements-testing.txt export PYTHON_VERSION=`python -c "import sys; print('.'.join(map(str, [sys.version_info.major, sys.version_info.minor])))"` echo "Python version detected:" echo $PYTHON_VERSION diff --git a/doc/index.rst b/doc/index.rst index 05380335..ba474764 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -44,18 +44,27 @@ Installation .. Attention:: To perform solver benchmarking, ODE-toolbox relies on GSL and PyGSL. Currently, the latest PyGSL release is not compatible with GSL. We recommend to use GSL 2.7 for now. This issue is being tracked at https://github.com/pygsl/pygsl/issues/62. +.. Attention:: Versions of sympy before 1.14.0 can introduce numerical precision errors and very long processing times. It is recommended to use the latest sympy version available. + Prerequisites ~~~~~~~~~~~~~ Only Python 3 is supported. ODE-toolbox depends on the Python packages SymPy, Cython, SciPy and NumPy (required), matplotlib and graphviz for visualisation (optional), and pytest for self-tests (also optional). The stiffness tester additionally depends on an installation of `PyGSL `__. If PyGSL is not installed, the test for stiffness is skipped during the analysis of the equations. -All required and optional packages can be installed by running +All required packages can be installed by running .. code:: bash pip install -r requirements.txt +Optional packages for testing can be installed by running + +.. code:: bash + + pip install -r requirements-testing.txt + + Installing ODE-toolbox ~~~~~~~~~~~~~~~~~~~~~~ @@ -376,14 +385,47 @@ ODE-toolbox will return a list of solvers. **Each solver has the following keys: - :python:`"initial_values"`\ : a dictionary that maps each variable symbol (in string form) to a SymPy expression. For example :python:`"g" : "e / tau"`. - :python:`"parameters"`\ : only present when parameters were supplied in the input. The input parameters are copied into the output for convenience. +**Numeric solvers have the following extra entries:** + +- :python:`"update_expressions"`\ : a dictionary that maps each variable symbol (in string form) to a SymPy expression that is its Jacobian, that is, for a symbol :math:`x`, the expression is equal to :math:`\frac{\delta x}{\delta t}`. + **Analytic solvers have the following extra entries:** -- :python:`"update_expressions"`\ : a dictionary that maps each variable symbol (in string form) to a SymPy propagator expression. The interpretation of an entry :python:`"g" : "g * __P__g__g + h * __P__g__h"` is that, at each integration timestep, when the state of the system needs to be updated from the current time :math:`t` to the next step :math:`t + \Delta t`, we assign the new value :python:`"g * __P__g__g + h * __P__g__h"` to the variable :python:`g`. Note that the expression is always evaluated at the old time :math:`t`; this means that when more than one state variable needs to be updated, all of the expressions have to be calculated before updating any of the variables. -- :python:`propagators`\ : a dictionary that maps each propagator matrix entry to its defining expression; for example :python:`"__P__g__h" : "__h*exp(-__h/tau)"` +In case of a solver without conditions: -**Numeric solvers have the following extra entries:** +- :python:`"update_expressions"`\ : a dictionary that maps each variable symbol (in string form) to a SymPy propagator expression. The interpretation of an entry :python:`"g" : "g * __P__g__g + h * __P__g__h"` is that, at each integration timestep, when the state of the system needs to be updated from the current time :math:`t` to the next step :math:`t + \Delta t`, we assign the new value :python:`"g * __P__g__g + h * __P__g__h"` to the variable :python:`g`. Note that the expression is always evaluated at the old time :math:`t`; this means that when more than one state variable needs to be updated, all of the expressions have to be calculated before updating any of the variables. +- :python:`propagators`\ : a dictionary that maps each propagator matrix entry to its defining expression; for example :python:`"__P__g__h" : "__h*exp(-__h/tau)"` -- :python:`"update_expressions"`\ : a dictionary that maps each variable symbol (in string form) to a SymPy expression that is its Jacobian, that is, for a symbol :math:`x`, the expression is equal to :math:`\frac{\delta x}{\delta t}`. +In some cases, parameter choices of the input can lead to numerical singularities (divisions by zero; see the sections :ref:`Computing the propagator matrix` and :ref:`Computing the update expressions` for more details). In this case, certain propagators and update expressions are only valid under certain conditions. These conditions, and the propagators and update expressions for each condition, are returned as follows: + +- :python:`"conditions"`\ : a dictionary that maps conditional expressions (as strings) to a nested dictionary containing the keys :python:`"update_expressions"` and :python:`"propagators"` as described in the previous paragraph. The default solver (in case none of the other conditions hold) is indicated by the key :python:`"default"`\ . + + For example, if the condition :math:`d=-p` will cause a singularity in the update expression for the state variable :python:`z`, two separate solvers are returned, one for the condition :math:`d=-p`, and another for the default ("otherwise") condition: + + .. code:: python + + { + "conditions": { + "(d == -p)": { + "propagators": { + "__P__z__z": "1" + }, + "update_expressions": { + "z": "__P__z__z*z + 1.5*__h*p/tau_z" + } + }, + "default": { + "propagators": { + "__P__z__z": "exp(-__h*(d + p)/tau_z)" + }, + "update_expressions": { + "z": "(__P__z__z*(0.5*d - p + z*(d + p)) - 0.5*d + p)/(d + p)" + } + } + } + } + + In general, there can be from 2 to any number of conditions. Conditions can involve boolean logic through the ``"&&"`` symbol for logical AND, ``"||"`` for logical OR, and the use of parentheses. The equality symbol, separating the left- and right-hand sides of the comparison, is written as ``"=="``. Analytic solver generation @@ -480,7 +522,7 @@ The propagator matrix :math:`\mathbf{P}` is derived from the system matrix by ma If the imaginary unit :math:`i` is found in any of the entries in :math:`\mathbf{P}`, fail. This usually indicates an unstable (diverging) dynamical system. Double-check the dynamical equations. -In some cases, elements of :math:`\mathbf{P}` may contain fractions that have a factor of the form :python:`param1 - param2` in their denominator. If at a later stage, the numerical value of :python:`param1` is chosen equal to that of :python:`param2`, a numerical singularity (division by zero) occurs. To avoid this issue, it is necessary to eliminate either :python:`param1` or :python:`param2` in the input, before the propagator matrix is generated. ODE-toolbox will detect conditions (in this example, :python:`param1 = param2`) under which these singularities can occur. If any conditions were found, log warning messages will be emitted during the computation of the propagator matrix. A condition is only reported if the system matrix :math:`A` is defined under that condition, ensuring that only those conditions are returned that are purely an artifact of the propagator computation. +In some cases, elements of :math:`\mathbf{P}` may contain fractions that have a factor of the form :python:`param1 - param2` in their denominator. If at a later stage, the numerical value of :python:`param1` is chosen equal to that of :python:`param2`, a numerical singularity (division by zero) occurs. To avoid this issue, it is necessary to eliminate either :python:`param1` or :python:`param2` in the input, before the propagator matrix is generated. ODE-toolbox will detect conditions (in this example, :python:`param1 = param2`) under which these singularities can occur. In case a potential division by zero is detected, separate, conditional solvers are generated, so that a valid solver can be selected (for the given choice of parameter values) during numerical integration. Internally, conditions are generated based on `sympy.calculus.singularities.singularities `_. To speed up processing, the final system matrix :math:`\mathbf{A}` is rewritten as a block-diagonal matrix :math:`\mathbf{A} = \text{diag}(\mathbf{A}_1, \mathbf{A}_2, \dots, \mathbf{A}_k)`, where each of :math:`\mathbf{A}_1, \mathbf{A}_2, \dots, \mathbf{A}_k` is square. Then, the propagator matrix is computed for each individual block separately, making use of the following identity: @@ -535,7 +577,7 @@ the update equation is: x \leftarrow P (x - 1.618) + 1.618 -In some cases, elements of :math:`\mathbf{A}` may contain terms that involve a parameter of the system to be integrated. If at a later stage, the numerical value of these parameters is chosen equal to zero, a numerical singularity (division by zero) occurs. To avoid this issue, it is necessary to invoke ODE-toolbox separately to generate an analytic solver for the special case that the value of :math:`\mathbf{A}` becomes equal to zero for inhomogeneous equations. ODE-toolbox will detect the conditions (for instance, :python:`param1 - param2 = 0`) under which these singularities occur. If any conditions were found, log warning messages will be emitted during the computation of the analytic solver. +In some cases, elements of :math:`\mathbf{A}` may contain terms that involve a parameter of the system to be integrated. If at a later stage, the numerical value of these parameters is chosen equal to zero, a numerical singularity (division by zero) occurs. ODE-toolbox will detect the conditions (for instance, :python:`param1 - param2 = 0`) under which these singularities occur. In case potential division by zero is detected, separate, conditional solvers are generated, so that a valid solver can be selected (for the given choice of parameter values) during numerical integration. Working with large expressions diff --git a/odetoolbox/__init__.py b/odetoolbox/__init__.py index 89c42cc6..499ff160 100644 --- a/odetoolbox/__init__.py +++ b/odetoolbox/__init__.py @@ -27,7 +27,7 @@ from sympy.core.expr import Expr as SympyExpr from .config import Config -from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _find_in_matrix, _is_zero, _is_sympy_type, SympyPrinter +from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _find_in_matrix, _is_zero, _is_sympy_type, SympyPrinter, _sympy_parse_real from .system_of_shapes import SystemOfShapes from .shapes import MalformedInputException, Shape @@ -127,6 +127,7 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym # validate input for forbidden names for var in set(all_variable_symbols) | all_parameter_symbols: _check_forbidden_name(var) + assert var.is_real # validate parameters for param in all_parameter_symbols: @@ -182,6 +183,32 @@ def _get_all_first_order_variables(indict) -> Iterable[str]: return variable_names +def symbol_appears_in_any_expr(param_name, solver_json) -> bool: + if "update_expressions" in solver_json.keys(): + for sym, expr in solver_json["update_expressions"].items(): + if param_name in [str(sym) for sym in list(expr.atoms())]: + return True + + if "propagators" in solver_json.keys(): + for sym, expr in solver_json["propagators"].items(): + if param_name in [str(sym) for sym in list(expr.atoms())]: + return True + + if "conditions" in solver_json.keys(): + for conditional_solver_json in solver_json["conditions"].values(): + if "update_expressions" in conditional_solver_json.keys(): + for sym, expr in conditional_solver_json["update_expressions"].items(): + if param_name in [str(sym) for sym in list(expr.atoms())]: + return True + + if "propagators" in conditional_solver_json.keys(): + for sym, expr in solver_json["propagators"].items(): + if param_name in [str(sym) for sym in list(expr.atoms())]: + return True + + return False + + def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> Tuple[List[Dict], SystemOfShapes, List[Shape]]: r""" Like analysis(), but additionally returns ``shape_sys`` and ``shapes``. @@ -208,14 +235,13 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so parameters = {} for k, v in indict["parameters"].items(): if type(k) is str: - parameters[sympy.Symbol(k)] = v + parameters[sympy.Symbol(k, real=True)] = v else: assert type(k) is sympy.Symbol parameters[k] = v _check_forbidden_name(k) - # # create Shapes and SystemOfShapes # @@ -236,7 +262,6 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so logging.debug("b = " + str(shape_sys.b_)) logging.debug("c = " + str(shape_sys.c_)) - # # generate analytical solutions (propagators) where possible # @@ -255,7 +280,6 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so analytic_solver_json["solver"] = "analytical" solvers_json.append(analytic_solver_json) - # # generate numerical solvers for the remainder # @@ -293,7 +317,6 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so solvers_json.append(solver_json) - # # copy the initial values from the input to the output for convenience; convert to numeric values # @@ -301,7 +324,7 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so for solver_json in solvers_json: solver_json["initial_values"] = {} for shape in shapes: - all_shape_symbols = [str(sympy.Symbol(str(shape.symbol) + Config().differential_order_symbol * i)) for i in range(shape.order)] + all_shape_symbols = [str(sympy.Symbol(str(shape.symbol) + Config().differential_order_symbol * i, real=True)) for i in range(shape.order)] for sym in all_shape_symbols: if sym in solver_json["state_variables"]: iv_expr = shape.get_initial_value(sym.replace(Config().differential_order_symbol, "'")) @@ -320,21 +343,8 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so solver_json["parameters"] = {} for param_name, param_expr in indict["parameters"].items(): # only make parameters appear in a solver if they are actually used there - symbol_appears_in_any_expr = False - if "update_expressions" in solver_json.keys(): - for sym, expr in solver_json["update_expressions"].items(): - if param_name in [str(sym) for sym in list(expr.atoms())]: - symbol_appears_in_any_expr = True - break - - if "propagators" in solver_json.keys(): - for sym, expr in solver_json["propagators"].items(): - if param_name in [str(sym) for sym in list(expr.atoms())]: - symbol_appears_in_any_expr = True - break - - if symbol_appears_in_any_expr: - sympy_expr = sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals) + if symbol_appears_in_any_expr(sym, solver_json): + sympy_expr = _sympy_parse_real(param_expr, global_dict=Shape._sympy_globals) # validate output for numerical problems for var in sympy_expr.atoms(): @@ -349,7 +359,6 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so solver_json["parameters"][param_name] = str(sympy_expr) - # # convert expressions from sympy to string # @@ -388,6 +397,26 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so for sym, expr in solver_json["propagators"].items(): solver_json["propagators"][sym] = str(expr) + if "conditions" in solver_json.keys(): + for cond, cond_solver in solver_json["conditions"].items(): + if "update_expressions" in cond_solver: + for sym, expr in cond_solver["update_expressions"].items(): + cond_solver["update_expressions"][sym] = str(expr) + + if preserve_expressions and sym in preserve_expressions: + if "analytic" in solver_json["solver"]: + logging.warning("Not preserving expression for variable \"" + sym + "\" as it is solved by propagator solver") + continue + + logging.info("Preserving expression for variable \"" + sym + "\"") + var_def_str = _find_variable_definition(indict, sym, order=1) + assert var_def_str is not None + cond_solver["update_expressions"][sym] = var_def_str.replace("'", Config().differential_order_symbol) + + if "propagators" in cond_solver: + for sym, expr in cond_solver["propagators"].items(): + cond_solver["propagators"][sym] = str(expr) + logging.info("In ode-toolbox: returning outdict = ") logging.info(json.dumps(solvers_json, indent=4, sort_keys=True)) diff --git a/odetoolbox/analytic_integrator.py b/odetoolbox/analytic_integrator.py index 2bbe6b1a..54b9b6e7 100644 --- a/odetoolbox/analytic_integrator.py +++ b/odetoolbox/analytic_integrator.py @@ -1,4 +1,3 @@ -# # analytic_integrator.py # # This file is part of the NEST ODE toolbox. @@ -19,13 +18,17 @@ # along with NEST. If not, see . # -from typing import Dict, List, Optional +import logging +from typing import Dict, List, Optional, Union import sympy import sympy.matrices import sympy.utilities import sympy.utilities.autowrap +from odetoolbox.config import Config +from odetoolbox.sympy_helpers import SymmetricEq, _sympy_parse_real + from .shapes import Shape from .integrator import Integrator @@ -47,7 +50,7 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] = self.solver_dict = solver_dict self.all_variable_symbols = self.solver_dict["state_variables"] - self.all_variable_symbols = [sympy.Symbol(s) for s in self.all_variable_symbols] + self.all_variable_symbols = [sympy.Symbol(s, real=True) for s in self.all_variable_symbols] self.set_spike_times(spike_times) @@ -55,28 +58,33 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] = self.enable_cache_update_ = True self.t = 0. - # # define the necessary numerical state variables # self.dim = len(self.all_variable_symbols) - self.initial_values = self.solver_dict["initial_values"].copy() + self.initial_values = {sympy.Symbol(k, real=True): v for k, v in self.solver_dict["initial_values"].items()} self.set_initial_values(self.initial_values) - self.shape_starting_values = self.solver_dict["initial_values"].copy() - for k, v in self.shape_starting_values.items(): - expr = sympy.parsing.sympy_parser.parse_expr(v, global_dict=Shape._sympy_globals) + self.shape_starting_values = {sympy.Symbol(k, real=True): v for k, v in self.solver_dict["initial_values"].items()} + for sym, v in self.shape_starting_values.items(): + expr = _sympy_parse_real(v, global_dict=Shape._sympy_globals) subs_dict = {} if "parameters" in self.solver_dict.keys(): - for k_, v_ in self.solver_dict["parameters"].items(): - subs_dict[k_] = v_ - self.shape_starting_values[k] = float(expr.evalf(subs=subs_dict)) + for parameter_name, v_ in self.solver_dict["parameters"].items(): + parameter_symbol = sympy.Symbol(parameter_name, real=True) + subs_dict[parameter_symbol] = v_ - self.update_expressions = self.solver_dict["update_expressions"].copy() - for k, v in self.update_expressions.items(): - if type(self.update_expressions[k]) is str: - self.update_expressions[k] = sympy.parsing.sympy_parser.parse_expr(self.update_expressions[k], global_dict=Shape._sympy_globals) + self.shape_starting_values[sym] = float(expr.evalf(subs=subs_dict)) + # + # initialise update expressions depending on whether conditional solver or not + # + + if "update_expressions" in self.solver_dict.keys(): + self._pick_unconditional_solver() + else: + assert "conditions" in self.solver_dict.keys() + self._pick_solver_based_on_condition() # # reset the system to t = 0 @@ -84,25 +92,118 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] = self.reset() + def _condition_holds(self, condition_string) -> bool: + r"""Check boolean conditions of the form: + + :: + + (p_1 && p_2 .. && p_k) || (q_1 && q_2 .. && q_j) || ... + + """ + for sub_condition_string in condition_string.split("||"): + # if any of the subterms hold, the whole expression holds (OR-ed together) + if self._and_condition_holds(sub_condition_string): + return True + + return False + + def _and_condition_holds(self, condition_string) -> bool: + r"""Check boolean conditions of the form: + + :: + + p_1 && p_2 .. && p_k + + """ + sub_conditions = condition_string.split("&&") + for sub_condition_string in sub_conditions: + sub_condition_string = sub_condition_string.strip().strip("()") + if "==" in sub_condition_string: + parts = sub_condition_string.split("==") + else: + parts = sub_condition_string.split("!=") + + lhs_str = parts[0].strip() + rhs_str = parts[1].strip() + lhs = _sympy_parse_real(lhs_str, global_dict=Shape._sympy_globals) + rhs = _sympy_parse_real(rhs_str, global_dict=Shape._sympy_globals) + + if "==" in sub_condition_string: + equation = SymmetricEq(lhs, rhs) + else: + equation = sympy.Ne(lhs, rhs) + + subs_dict = {} + if "parameters" in self.solver_dict.keys(): + for param_name, param_val in self.solver_dict["parameters"].items(): + param_symbol = sympy.Symbol(param_name, real=True) + subs_dict[param_symbol] = param_val + + sub_condition_holds = equation.subs(subs_dict) + + if not sub_condition_holds: + # if any of the subterms do not hold, the whole expression does not hold (AND-ed together) + return False + + return True + + def _pick_unconditional_solver(self): + self.update_expressions = self.solver_dict["update_expressions"].copy() + self.propagators = self.solver_dict["propagators"].copy() + self._process_update_expressions_from_solver_dict() + + def _pick_solver_based_on_condition(self): + r"""In case of a conditional propagator solver: pick a solver depending on the conditions that hold (depending on parameter values)""" + self.update_expressions = self.solver_dict["conditions"]["default"]["update_expressions"] + self.propagators = self.solver_dict["conditions"]["default"]["propagators"] + for condition, conditional_solver in self.solver_dict["conditions"].items(): + if condition != "default" and self._condition_holds(condition): + self.update_expressions = conditional_solver["update_expressions"] + self.propagators = conditional_solver["propagators"] + logging.debug("Picking solver based on condition: " + str(condition)) + + break + + self._process_update_expressions_from_solver_dict() + + def _process_update_expressions_from_solver_dict(self): # - # in the update expression, replace symbolic variables with their numerical values + # create substitution dictionary to replace symbolic variables with their numerical values # - self.subs_dict = {} - for prop_symbol, prop_expr in self.solver_dict["propagators"].items(): - self.subs_dict[prop_symbol] = prop_expr + subs_dict = {} + for prop_name, prop_expr in self.propagators.items(): + subs_dict[prop_name] = prop_expr + if "parameters" in self.solver_dict.keys(): - for param_symbol, param_expr in self.solver_dict["parameters"].items(): - self.subs_dict[param_symbol] = param_expr + for param_name, param_expr in self.solver_dict["parameters"].items(): + subs_dict[param_name] = param_expr + # subs_dict = {sympy.Symbol(k, real=True): v for k, v in subs_dict.items()} + subs_dict = {sympy.Symbol(k, real=True): v if type(v) is float or isinstance(v, sympy.Expr) else _sympy_parse_real(v, global_dict=Shape._sympy_globals) for k, v in subs_dict.items()} + + # + # parse the expressions from JSON if necessary + # + + for k, v in self.update_expressions.items(): + if type(self.update_expressions[k]) is str: + self.update_expressions[k] = _sympy_parse_real(self.update_expressions[k], global_dict=Shape._sympy_globals) # # perform substitution in update expressions ahead of time to save time later # for k, v in self.update_expressions.items(): - self.update_expressions[k] = self.update_expressions[k].subs(self.subs_dict).subs(self.subs_dict) + for sym in self.update_expressions[k].free_symbols: + assert sym.is_real + self.update_expressions[k] = self.update_expressions[k].subs(subs_dict) + for sym in self.update_expressions[k].free_symbols: + assert sym.is_real + self.update_expressions[k] = self.update_expressions[k].subs(subs_dict) + for sym in self.update_expressions[k].free_symbols: + assert sym.is_real # # autowrap @@ -111,29 +212,25 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] = self.update_expressions_wrapped = {} for k, v in self.update_expressions.items(): self.update_expressions_wrapped[k] = sympy.utilities.autowrap.autowrap(v, - args=[sympy.Symbol("__h")] + self.all_variable_symbols, + args=[sympy.Symbol(Config().output_timestep_symbol, real=True)] + self.all_variable_symbols, backend="cython", helpers=Shape._sympy_autowrap_helpers) - def get_all_variable_symbols(self): return self.all_variable_symbols - def enable_cache_update(self): r""" Allow caching of results between requested times. """ self.enable_cache_update_ = True - def disable_cache_update(self): r""" Disallow caching of results between requested times. """ self.enable_cache_update_ = False - def reset(self): r""" Reset time to zero and state to initial values. @@ -141,35 +238,45 @@ def reset(self): self.t_curr = 0. self.state_at_t_curr = self.initial_values.copy() - - def set_initial_values(self, vals): + def set_initial_values(self, vals: Union[Dict[str, str], Dict[sympy.Symbol, sympy.Expr]]): r""" Set initial values, i.e. the state of the system at :math:`t = 0`. This will additionally cause the system state to be reset to :math:`t = 0` and the new initial conditions. :param vals: New initial values. """ - for k, v in vals.items(): - k = str(k) - assert k in self.initial_values.keys(), "Tried to set initial value for unknown parameter \"" + str(k) + "\"" - expr = sympy.parsing.sympy_parser.parse_expr(str(v), global_dict=Shape._sympy_globals) + for sym, expr in vals.items(): + if type(sym) is str: + sym = sympy.Symbol(sym, real=True) + + assert sym in self.initial_values.keys(), "Tried to set initial value for unknown parameter \"" + str(k) + "\"" + + if type(expr) is str: + expr = _sympy_parse_real(expr, global_dict=Shape._sympy_globals) + subs_dict = {} if "parameters" in self.solver_dict.keys(): - for param_symbol, param_val in self.solver_dict["parameters"].items(): + for param_name, param_val in self.solver_dict["parameters"].items(): + param_symbol = sympy.Symbol(param_name, real=True) subs_dict[param_symbol] = param_val + try: - self.initial_values[k] = float(expr.evalf(subs=subs_dict)) + if type(expr) is float: + self.initial_values[sym] = expr + else: + self.initial_values[sym] = float(expr.evalf(subs=subs_dict)) except TypeError: msg = "Could not convert initial value expression to float. The following symbol(s) may be undeclared: " + ", ".join([str(expr_) for expr_ in expr.evalf(subs=subs_dict).free_symbols]) raise Exception(msg) - self.reset() + self.reset() - def _update_step(self, delta_t, initial_values): + def _update_step(self, delta_t, initial_values) -> Dict[sympy.Symbol, sympy.Expr]: r""" Apply propagator to update the state, starting from `initial_values`, by timestep `delta_t`. :param delta_t: Timestep to take. :param initial_values: A dictionary mapping variable names (as strings) to initial value expressions. + :return new_state: A dictionary mapping symbols to state values. """ new_state = {} @@ -178,24 +285,23 @@ def _update_step(self, delta_t, initial_values): # replace expressions by their numeric values # - y = [delta_t] + [initial_values[str(sym)] for sym in self.all_variable_symbols] - + y = [delta_t] + [initial_values[sym] for sym in self.all_variable_symbols] # # for each state variable, perform the state update # for state_variable, expr in self.update_expressions.items(): - new_state[state_variable] = self.update_expressions_wrapped[state_variable](*y) + new_state[sympy.Symbol(state_variable, real=True)] = self.update_expressions_wrapped[state_variable](*y) return new_state - - def get_value(self, t): + def get_value(self, t: float) -> Dict[sympy.Symbol, sympy.Expr]: r""" Get numerical solution of the dynamical system at time :python:`t`. :param t: The time to compute the solution for. + :return state: A dictionary mapping symbols to state values. """ if (not self.enable_caching) \ @@ -211,7 +317,6 @@ def get_value(self, t): all_spike_times, all_spike_times_sym = self.get_sorted_spike_times() - # # process spikes between ⟨t_curr, t] # @@ -224,7 +329,6 @@ def get_value(self, t): if spike_t > t: break - # # apply propagator to update the state from `t_curr` to `spike_t` # @@ -234,7 +338,6 @@ def get_value(self, t): state_at_t_curr = self._update_step(delta_t, state_at_t_curr) t_curr = spike_t - # # delta impulse increment # @@ -243,7 +346,6 @@ def get_value(self, t): if spike_sym in self.initial_values.keys(): state_at_t_curr[spike_sym] += self.shape_starting_values[spike_sym] - # # update cache with the value at the last spike time (if we update with the value at the last requested time (`t`), we would accumulate roundoff errors) # @@ -252,7 +354,6 @@ def get_value(self, t): self.t_curr = t_curr self.state_at_t_curr = state_at_t_curr - # # apply propagator to update the state from `t_curr` to `t` # diff --git a/odetoolbox/integrator.py b/odetoolbox/integrator.py index 7927ae96..afcecf94 100644 --- a/odetoolbox/integrator.py +++ b/odetoolbox/integrator.py @@ -31,7 +31,9 @@ class Integrator: Integrate a dynamical system by means of the propagators returned by ODE-toolbox (base class). """ - all_variable_symbols = [] # type: List[sympy.Symbol] + all_variable_symbols: List[sympy.Symbol] = [] + all_spike_times: List[float] = [] + all_spike_times_sym: List[List[sympy.Symbol]] = [] def set_spike_times(self, spike_times: Optional[Dict[str, List[float]]]): r""" @@ -39,13 +41,16 @@ def set_spike_times(self, spike_times: Optional[Dict[str, List[float]]]): :param spike_times: For each variable, used as a key, the list of spike times associated with it. """ + if spike_times is None: self.spike_times = {} else: self.spike_times = spike_times.copy() + assert all([type(sym) is str for sym in self.spike_times.keys()]), "Spike time keys need to be of type str" - self.all_spike_times = [] # type: List[float] - self.all_spike_times_sym = [] # type: List[List[str]] + + self.all_spike_times = [] + self.all_spike_times_sym = [] for sym, sym_spike_times in self.spike_times.items(): assert type(sym) is str assert str(sym) in [str(_sym) for _sym in self.all_variable_symbols], "Tried to set a spike time of unknown symbol \"" + sym + "\"" @@ -59,8 +64,7 @@ def set_spike_times(self, spike_times: Optional[Dict[str, List[float]]]): idx = np.argsort(self.all_spike_times) self.all_spike_times = [self.all_spike_times[i] for i in idx] - self.all_spike_times_sym = [self.all_spike_times_sym[i] for i in idx] - + self.all_spike_times_sym = [[sympy.Symbol(sym, real=True) for sym in self.all_spike_times_sym[i]] for i in idx] def get_spike_times(self): r""" @@ -70,7 +74,6 @@ def get_spike_times(self): """ return self.spike_times - def get_sorted_spike_times(self): r""" Returns a global, sorted list of spike times. diff --git a/odetoolbox/mixed_integrator.py b/odetoolbox/mixed_integrator.py index 9509d7f6..c5f53492 100644 --- a/odetoolbox/mixed_integrator.py +++ b/odetoolbox/mixed_integrator.py @@ -19,24 +19,23 @@ # along with NEST. If not, see . # -from typing import Optional +from typing import Iterable, Optional import logging import numpy as np -import numpy.random import os import sympy import sympy.utilities.autowrap from sympy.utilities.autowrap import CodeGenArgumentListError import time - from .analytic_integrator import AnalyticIntegrator from .config import Config from .integrator import Integrator from .plot_helper import import_matplotlib from .shapes import Shape -from .sympy_helpers import _is_sympy_type +from .system_of_shapes import SystemOfShapes +from .sympy_helpers import _is_sympy_type, _sympy_parse_real try: import pygsl.odeiv as odeiv @@ -59,7 +58,7 @@ class MixedIntegrator(Integrator): Mixed numeric+analytic integrator. Supply with a result from ODE-toolbox analysis; calculates numeric approximation of the solution. """ - def __init__(self, numeric_integrator, system_of_shapes, shapes, analytic_solver_dict=None, parameters=None, spike_times=None, random_seed=123, max_step_size=np.inf, integration_accuracy_abs=1E-6, integration_accuracy_rel=1E-6, sim_time=1., alias_spikes=False, debug_plot_dir: Optional[str] = None): + def __init__(self, numeric_integrator, system_of_shapes: SystemOfShapes, shapes: Iterable[Shape], analytic_solver_dict=None, parameters=None, spike_times=None, random_seed=123, max_step_size=np.inf, integration_accuracy_abs=1E-6, integration_accuracy_rel=1E-6, sim_time=1., alias_spikes=False, debug_plot_dir: Optional[str] = None): r""" :param numeric_integrator: A method from the GSL library for evolving ODEs, e.g. :python:`odeiv.step_rk4` :param system_of_shapes: Dynamical system to solve. @@ -93,8 +92,10 @@ def __init__(self, numeric_integrator, system_of_shapes, shapes, analytic_solver self._parameters = {} else: self._parameters = parameters - self._parameters = {k: sympy.parsing.sympy_parser.parse_expr(v, global_dict=Shape._sympy_globals).n() if not _is_sympy_type(v) else v for k, v in self._parameters.items()} - self._locals = self._parameters.copy() + self._parameters = {k: _sympy_parse_real(v, global_dict=Shape._sympy_globals).n() if not _is_sympy_type(v) else v for k, v in self._parameters.items()} + self._locals = {} + for parameter_name, parameter_expr in self._parameters.items(): + self._locals[sympy.Symbol(parameter_name, real=True)] = parameter_expr self.random_seed = random_seed self.analytic_solver_dict = analytic_solver_dict @@ -109,7 +110,7 @@ def __init__(self, numeric_integrator, system_of_shapes, shapes, analytic_solver self.all_variable_symbols = list(self._system_of_shapes.x_) if not self.analytic_solver_dict is None: self.all_variable_symbols += self.analytic_solver_dict["state_variables"] - self.all_variable_symbols = [sympy.Symbol(str(sym).replace("'", Config().differential_order_symbol)) for sym in self.all_variable_symbols] + self.all_variable_symbols = [sympy.Symbol(str(sym).replace("'", Config().differential_order_symbol), real=True) for sym in self.all_variable_symbols] for sym, expr in self._update_expr.items(): try: @@ -127,14 +128,12 @@ def __init__(self, numeric_integrator, system_of_shapes, shapes, analytic_solver backend="cython", helpers=Shape._sympy_autowrap_helpers) - # # make a sorted list of all spike times for all symbols # self.set_spike_times(spike_times) - def integrate_ode(self, initial_values=None, h_min_lower_bound=5E-9, raise_errors=True, debug=False): r""" This function computes the average step size and the minimal step size that a given integration method from GSL uses to evolve a certain system of ODEs during a certain simulation time, integration method from GSL and spike train. @@ -157,7 +156,6 @@ def integrate_ode(self, initial_values=None, h_min_lower_bound=5E-9, raise_error all_spike_times, all_spike_times_sym = self.get_sorted_spike_times() - # # initialise analytic integrator # @@ -168,14 +166,13 @@ def integrate_ode(self, initial_values=None, h_min_lower_bound=5E-9, raise_error analytic_integrator_initial_values = {sym: iv for sym, iv in initial_values.items() if sym in self.analytic_integrator.get_all_variable_symbols()} self.analytic_integrator.set_initial_values(analytic_integrator_initial_values) - # # convert initial value expressions to floats # for sym in self._system_of_shapes.x_: if not sym in initial_values.keys(): - initial_values[sym] = float(self._system_of_shapes.get_initial_value(str(sym)).evalf(subs=self._parameters)) + initial_values[sym] = float(self._system_of_shapes.get_initial_value(str(sym)).evalf(subs={sympy.Symbol(sym, real=True): v for sym, v in self._parameters.items()})) upper_bound_crossed = False y = np.array([initial_values[sym] for sym in self._system_of_shapes.x_]) @@ -189,7 +186,6 @@ def integrate_ode(self, initial_values=None, h_min_lower_bound=5E-9, raise_error control = odeiv.control_y_new(gsl_stepper, self.integration_accuracy_abs, self.integration_accuracy_rel) evolve = odeiv.evolve(gsl_stepper, control, len(y)) - # # make NumPy warnings errors. Without this, we can't catch overflow errors that can occur in the step() function, which might indicate a problem with the ODE, the grid resolution or the stiffness testing framework itself. # @@ -205,7 +201,6 @@ def integrate_ode(self, initial_values=None, h_min_lower_bound=5E-9, raise_error time_start = time.time() - # # main loop # @@ -275,7 +270,6 @@ def integrate_ode(self, initial_values=None, h_min_lower_bound=5E-9, raise_error h_sum += h_suggested n_timesteps_taken += 1 - # # enforce bounds/thresholds # @@ -288,17 +282,15 @@ def integrate_ode(self, initial_values=None, h_min_lower_bound=5E-9, raise_error upper_bound_crossed = True y[idx] = initial_values[shape.symbol] - # # evaluate to numeric values those ODEs that are solved analytically # - self._locals.update({str(sym): y[i] for i, sym in enumerate(self._system_of_shapes.x_)}) + self._locals.update({sym: y[i] for i, sym in enumerate(self._system_of_shapes.x_)}) if not self.analytic_integrator is None: self._locals.update(self.analytic_integrator.get_value(t)) - # # apply the spikes, i.e. add the "initial values" to the system dynamical state vector # @@ -317,8 +309,8 @@ def integrate_ode(self, initial_values=None, h_min_lower_bound=5E-9, raise_error while t_next_spike <= t: syms_next_spike = all_spike_times_sym[idx_next_spike] for sym in syms_next_spike: - if sym in [str(sym_) for sym_ in self._system_of_shapes.x_]: - idx = [str(sym_) for sym_ in list(self._system_of_shapes.x_)].index(sym) + if str(sym) in [str(sym_) for sym_ in self._system_of_shapes.x_]: + idx = [str(sym_) for sym_ in list(self._system_of_shapes.x_)].index(str(sym)) y[idx] += float(self._system_of_shapes.get_initial_value(sym).evalf(subs=self._locals)) idx_next_spike += 1 if idx_next_spike < len(all_spike_times): @@ -327,8 +319,8 @@ def integrate_ode(self, initial_values=None, h_min_lower_bound=5E-9, raise_error t_next_spike = np.inf else: for sym in syms_next_spike: - if sym in [str(sym_) for sym_ in self._system_of_shapes.x_]: - idx = [str(sym_) for sym_ in list(self._system_of_shapes.x_)].index(sym) + if str(sym) in [str(sym_) for sym_ in self._system_of_shapes.x_]: + idx = [str(sym_) for sym_ in list(self._system_of_shapes.x_)].index(str(sym)) y[idx] += float(self._system_of_shapes.get_initial_value(sym).evalf(subs=self._locals)) h_avg = h_sum / n_timesteps_taken @@ -354,7 +346,6 @@ def integrate_ode(self, initial_values=None, h_min_lower_bound=5E-9, raise_error else: return h_min, h_avg, runtime - def integrator_debug_plot(self, t_log, h_log, y_log, dir): mpl, plt = import_matplotlib() assert mpl, "Debug plot was requested for MixedIntegrator, but an exception occurred while importing matplotlib. See the ``debug_plot_dir`` parameter." @@ -413,7 +404,6 @@ def integrator_debug_plot(self, t_log, h_log, y_log, dir): plt.savefig(fn, dpi=600) plt.close(fig) - def numerical_jacobian(self, t, y, params): r""" Compute the numerical values of the Jacobian matrix at the current time :python:`t` and state :python:`y`. @@ -430,12 +420,12 @@ def numerical_jacobian(self, t, y, params): dfdy = np.zeros((dimension, dimension), float) dfdt = np.zeros((dimension,)) - self._locals.update({str(sym): y[i] for i, sym in enumerate(self._system_of_shapes.x_)}) + self._locals.update({sym: y[i] for i, sym in enumerate(self._system_of_shapes.x_)}) if not self.analytic_integrator is None: self._locals.update(self.analytic_integrator.get_value(t)) - y = [self._locals[str(sym)] for sym in self.all_variable_symbols] + y = [self._locals[sym] for sym in self.all_variable_symbols] for row in range(0, dimension): for col in range(0, dimension): @@ -444,7 +434,6 @@ def numerical_jacobian(self, t, y, params): return dfdy, dfdt - def step(self, t, y, params): r""" "Stepping function": compute the (numerical) value of the derivative of :python:`y` over time, at the current time :python:`t` and state :python:`y`. @@ -455,7 +444,7 @@ def step(self, t, y, params): :return: Updated state vector """ - self._locals.update({str(sym): y[i] for i, sym in enumerate(self._system_of_shapes.x_)}) + self._locals.update({sym: y[i] for i, sym in enumerate(self._system_of_shapes.x_)}) # # update state of analytically solved variables to time `t` @@ -465,7 +454,7 @@ def step(self, t, y, params): self._locals.update(self.analytic_integrator.get_value(t)) # y holds the state of all the symbols in the numeric part of the system; add those for the analytic part - y = [self._locals[str(sym)] for sym in self.all_variable_symbols] + y = [self._locals[sym] for sym in self.all_variable_symbols] try: # return [ float(self._update_expr[str(sym)].evalf(subs=self._locals)) for sym in self._system_of_shapes.x_ ] # non-wrapped version diff --git a/odetoolbox/shapes.py b/odetoolbox/shapes.py index 0cbf3ef6..b22d16e3 100644 --- a/odetoolbox/shapes.py +++ b/odetoolbox/shapes.py @@ -32,7 +32,7 @@ from sympy.core.expr import Expr as SympyExpr from .config import Config -from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _custom_simplify_expr, _is_constant_term, _is_sympy_type, _is_zero +from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _custom_simplify_expr, _is_constant_term, _is_sympy_type, _is_zero, _sympy_parse_real class MalformedInputException(Exception): @@ -88,15 +88,15 @@ class Shape: "Heaviside": sympy.Heaviside, "e": sympy.exp(1), "E": sympy.exp(1), - "t": sympy.Symbol("t"), + "t": sympy.Symbol("t", real=True), "DiracDelta": sympy.DiracDelta} # cython backend (used by sympy autowrap()) cannot handle these functions; need to provide alternative implementation - _sympy_autowrap_helpers = [("Min", (abs(sympy.symbols("x") + sympy.symbols("y")) - abs(sympy.symbols("x") - sympy.symbols("y"))) / 2, [sympy.symbols("x"), sympy.symbols("y")]), - ("Max", (abs(sympy.symbols("x") + sympy.symbols("y")) + abs(sympy.symbols("x") - sympy.symbols("y"))) / 2, [sympy.symbols("x"), sympy.symbols("y")]), - ("Heaviside", (sympy.symbols("x") + abs(sympy.symbols("x"))) / (2 * abs(sympy.symbols("x")) + 1E-300), [sympy.symbols("x")])] + _sympy_autowrap_helpers = [("Min", (abs(sympy.symbols("x", real=True) + sympy.symbols("y", real=True)) - abs(sympy.symbols("x", real=True) - sympy.symbols("y", real=True))) / 2, [sympy.symbols("x", real=True), sympy.symbols("y", real=True)]), + ("Max", (abs(sympy.symbols("x", real=True) + sympy.symbols("y", real=True)) + abs(sympy.symbols("x", real=True) - sympy.symbols("y", real=True))) / 2, [sympy.symbols("x", real=True), sympy.symbols("y", real=True)]), + ("Heaviside", (sympy.symbols("x", real=True) + abs(sympy.symbols("x", real=True))) / (2 * abs(sympy.symbols("x", real=True)) + 1E-300), [sympy.symbols("x", real=True)])] - def __init__(self, symbol, order, initial_values, derivative_factors, inhom_term=sympy.Float(0.), nonlin_term=sympy.Float(0.), lower_bound=None, upper_bound=None): + def __init__(self, symbol: sympy.Symbol, order, initial_values, derivative_factors, inhom_term=sympy.Float(0.), nonlin_term=sympy.Float(0.), lower_bound=None, upper_bound=None): r""" Perform type and consistency checks and assign arguments to member variables. @@ -110,6 +110,18 @@ def __init__(self, symbol, order, initial_values, derivative_factors, inhom_term if not type(symbol) is sympy.Symbol: raise MalformedInputException("symbol is not a SymPy symbol: \"%r\"" % symbol) + assert symbol.is_real + + for derivative_factor in derivative_factors: + for _sym in derivative_factor.free_symbols: + assert _sym.is_real + + for _sym in inhom_term.free_symbols: + assert _sym.is_real + + for _sym in nonlin_term.free_symbols: + assert _sym.is_real + self.symbol = symbol if str(symbol) in Shape._sympy_globals.keys(): @@ -154,27 +166,30 @@ def __init__(self, symbol, order, initial_values, derivative_factors, inhom_term self.lower_bound = lower_bound if not self.lower_bound is None: + if type(self.lower_bound) is str: + self.lower_bound = _sympy_parse_real(self.lower_bound) + self.lower_bound = _custom_simplify_expr(self.lower_bound) self.upper_bound = upper_bound if not self.upper_bound is None: + if type(self.upper_bound) is str: + self.upper_bound = _sympy_parse_real(self.upper_bound) + self.upper_bound = _custom_simplify_expr(self.upper_bound) logging.debug("Created Shape with symbol " + str(self.symbol) + ", derivative_factors = " + str(self.derivative_factors) + ", inhom_term = " + str(self.inhom_term) + ", nonlin_term = " + str(self.nonlin_term)) - def __str__(self): s = "Shape \"" + str(self.symbol) + "\" of order " + str(self.order) return s - def is_homogeneous(self) -> bool: r""" :return: :python:`False` if and only if the shape has a nonzero right-hand side. """ return _is_zero(self.inhom_term) - def get_initial_value(self, sym: str): r""" Get the initial value corresponding to the variable symbol. @@ -185,7 +200,6 @@ def get_initial_value(self, sym: str): return None return self.initial_values[sym] - def get_state_variables(self, derivative_symbol="'") -> List[sympy.Symbol]: r""" Get all variable symbols for this shape, ordered according to derivative order, up to the shape's order :math:`N`: :python:`[sym, dsym/dt, d^2sym/dt^2, ..., d^(N-1)sym/dt^(N-1)]` @@ -193,11 +207,10 @@ def get_state_variables(self, derivative_symbol="'") -> List[sympy.Symbol]: all_symbols = [] for order in range(self.order): - all_symbols.append(sympy.Symbol(str(self.symbol) + derivative_symbol * order)) + all_symbols.append(sympy.Symbol(str(self.symbol) + derivative_symbol * order, real=True)) return all_symbols - def get_all_variable_symbols(self, shapes=None, derivative_symbol="'") -> List[sympy.Symbol]: r""" Get all variable symbols for this shape and all other shapes in :python:`shapes`, without duplicates, in no particular order. @@ -220,14 +233,12 @@ def get_all_variable_symbols(self, shapes=None, derivative_symbol="'") -> List[s return all_symbols - def is_lin_const_coeff(self) -> bool: r""" :return: :python:`True` if and only if the shape is linear and constant coefficient. """ return _is_zero(self.nonlin_term) - def is_lin_const_coeff_in(self, symbols, parameters=None): r""" :return: :python:`True` if and only if the shape is linear and constant coefficient in those variables passed in ``symbols``. @@ -236,7 +247,6 @@ def is_lin_const_coeff_in(self, symbols, parameters=None): derivative_factors, inhom_term, nonlin_term = Shape.split_lin_inhom_nonlin(expr, symbols, parameters=parameters) return _is_zero(nonlin_term) - @classmethod def _parse_defining_expression(cls, s: str) -> Tuple[str, int, str]: r"""Parse a defining expression, for example, if the ODE-toolbox JSON input file contains the snippet: @@ -265,7 +275,6 @@ def _parse_defining_expression(cls, s: str) -> Tuple[str, int, str]: order = len(re.findall("'", lhs)) return symbol, order, rhs - @classmethod def from_json(cls, indict, all_variable_symbols=None, parameters=None, _debug=False): r""" @@ -337,7 +346,6 @@ def from_json(cls, indict, all_variable_symbols=None, parameters=None, _debug=Fa return Shape.from_ode(symbol, rhs, initial_values, all_variable_symbols=all_variable_symbols, lower_bound=lower_bound, upper_bound=upper_bound, parameters=parameters) - def reconstitute_expr(self) -> SympyExpr: r""" Recreate right-hand side expression from internal representation (linear coefficients, inhomogeneous, and nonlinear parts). @@ -346,10 +354,15 @@ def reconstitute_expr(self) -> SympyExpr: derivative_symbols = self.get_state_variables(derivative_symbol=Config().differential_order_symbol) for derivative_factor, derivative_symbol in zip(self.derivative_factors, derivative_symbols): expr += derivative_factor * derivative_symbol + for sym in derivative_factor.free_symbols: + assert sym.is_real + for sym in derivative_symbol.free_symbols: + assert sym.is_real logging.debug("Shape " + str(self.symbol) + ": reconstituting expression " + str(expr)) + for sym in expr.free_symbols: + assert sym.is_real return expr - @staticmethod def split_lin_inhom_nonlin(expr, x, parameters=None): r""" @@ -401,8 +414,13 @@ def split_lin_inhom_nonlin(expr, x, parameters=None): logging.debug("\tinhomogeneous term: " + str(inhom_term)) logging.debug("\tnonlinear term: " + str(nonlin_term)) - return lin_factors, inhom_term, nonlin_term + for sym in nonlin_term.free_symbols: + assert sym.is_real + for sym in inhom_term.free_symbols: + assert sym.is_real + + return lin_factors, inhom_term, nonlin_term @classmethod def from_function(cls, symbol: str, definition, max_t=100, max_order=4, all_variable_symbols=None, debug=False) -> Shape: @@ -429,21 +447,20 @@ def from_function(cls, symbol: str, definition, max_t=100, max_order=4, all_vari all_variable_symbols_dict = {str(el): el for el in all_variable_symbols} - definition = sympy.parsing.sympy_parser.parse_expr(definition, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) + definition = _sympy_parse_real(definition, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) - # `derivatives` is a list of all derivatives of `shape` up to the order we are checking, starting at 0. - derivatives = [definition, sympy.diff(definition, Config().input_time_symbol)] + # ``derivatives`` is a list of all derivatives of `shape` up to the order we are checking, starting at 0. + derivatives = [definition, sympy.diff(definition, sympy.Symbol(Config().input_time_symbol, real=True))] logging.debug("Processing function-of-time shape \"" + symbol + "\" with defining expression = \"" + str(definition) + "\"") - # # to avoid a division by zero below, we have to find a `t` so that the shape function is not zero at this `t`. # t_val = None for t_ in range(0, max_t): - if not _is_zero(definition.subs(Config().input_time_symbol, t_)): + if not _is_zero(definition.subs(sympy.Symbol(Config().input_time_symbol, real=True), t_)): t_val = t_ break @@ -458,20 +475,18 @@ def from_function(cls, symbol: str, definition, max_t=100, max_order=4, all_vari msg = "Cannot find t for which shape function is unequal to zero" raise Exception(msg) - # # first handle the case for an ODE of order 1, i.e. of the form I' = a0 * I # order = 1 - logging.debug("\tFinding ode for order 1...") + logging.debug("\tFinding ODE of order 1...") - derivative_factors = [(1 / derivatives[0] * derivatives[1]).subs(Config().input_time_symbol, t_val)] + derivative_factors = [(1 / derivatives[0] * derivatives[1]).subs(sympy.Symbol(Config().input_time_symbol, real=True), t_val)] diff_rhs_lhs = derivatives[1] - derivative_factors[0] * derivatives[0] found_ode = _is_zero(diff_rhs_lhs) - # # If `shape` does not satisfy a linear homogeneous ODE of order 1, we try to find one of higher order in a loop. The loop runs while no linear homogeneous ODE was found and the maximum order to check for was not yet reached. # @@ -479,10 +494,10 @@ def from_function(cls, symbol: str, definition, max_t=100, max_order=4, all_vari while not found_ode and order < max_order: order += 1 - logging.debug("\tFinding ode for order " + str(order) + "...") + logging.debug("\tFinding ODE of order " + str(order) + "...") # Add the next higher derivative to the list - derivatives.append(sympy.diff(derivatives[-1], Config().input_time_symbol)) + derivatives.append(sympy.diff(derivatives[-1], sympy.Symbol(Config().input_time_symbol, real=True))) X = sympy.zeros(order) @@ -494,9 +509,9 @@ def from_function(cls, symbol: str, definition, max_t=100, max_order=4, all_vari for t_ in range(1, max_t): for i in range(order): substitute = i + t_ - Y[i] = derivatives[order].subs(Config().input_time_symbol, substitute) + Y[i] = derivatives[order].subs(sympy.Symbol(Config().input_time_symbol, real=True), substitute) for j in range(order): - X[i, j] = derivatives[j].subs(Config().input_time_symbol, substitute) + X[i, j] = derivatives[j].subs(sympy.Symbol(Config().input_time_symbol, real=True), substitute) if not _is_zero(sympy.det(X)): invertible = True @@ -509,14 +524,12 @@ def from_function(cls, symbol: str, definition, max_t=100, max_order=4, all_vari if not invertible: continue - # # calculate `derivative_factors` # derivative_factors = sympy.simplify(X.inv() * Y) # XXX: need sympy.simplify() here rather than _custom_simplify_expr() - # # fill in the obtained expressions for the derivative_factors and check whether they satisfy the definition of the shape # @@ -540,10 +553,9 @@ def from_function(cls, symbol: str, definition, max_t=100, max_order=4, all_vari # calculate the initial values of the found ODE # - initial_values = {symbol + derivative_order * "'": x.subs(Config().input_time_symbol, 0) for derivative_order, x in enumerate(derivatives[:-1])} - - return cls(sympy.Symbol(symbol), order, initial_values, derivative_factors) + initial_values = {symbol + derivative_order * "'": x.subs(sympy.Symbol(Config().input_time_symbol, real=True), 0) for derivative_order, x in enumerate(derivatives[:-1])} + return cls(sympy.Symbol(symbol, real=True), order, initial_values, derivative_factors) @classmethod def from_ode(cls, symbol: str, definition: str, initial_values: dict, all_variable_symbols=None, lower_bound=None, upper_bound=None, parameters=None, debug=False, **kwargs) -> Shape: @@ -567,6 +579,9 @@ def from_ode(cls, symbol: str, definition: str, initial_values: dict, all_variab assert type(symbol) is str assert type(definition) is str assert type(initial_values) is dict + if all_variable_symbols: + for sym in all_variable_symbols: + assert sym.is_real logging.debug("\nProcessing differential-equation form shape " + str(symbol) + " with defining expression = \"" + str(definition) + "\"") @@ -575,16 +590,20 @@ def from_ode(cls, symbol: str, definition: str, initial_values: dict, all_variab order: int = len(initial_values) all_variable_symbols_dict = {str(el): el for el in all_variable_symbols} - definition = sympy.parsing.sympy_parser.parse_expr(definition.replace("'", Config().differential_order_symbol), global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) # minimal global_dict to make no assumptions (e.g. "beta" could otherwise be recognised as a function instead of as a parameter symbol) + if parameters: + parameter_symbols_real_dict = {str(sym): sym for sym in parameters.keys()} # make a dict with a symbol for each parameter to make sure we can set the domain to Real + else: + parameter_symbols_real_dict = {} + definition = _sympy_parse_real(definition.replace("'", Config().differential_order_symbol), global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict | parameter_symbols_real_dict) # minimal global_dict to make no assumptions (e.g. "beta" could otherwise be recognised as a function instead of as a parameter symbol) # validate input for forbidden names - _initial_values = {k: sympy.parsing.sympy_parser.parse_expr(v, evaluate=False, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()} + _initial_values = {k: _sympy_parse_real(v, evaluate=False, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()} for iv_expr in _initial_values.values(): for var in iv_expr.atoms(): _check_forbidden_name(var) # parse input - initial_values = {k: sympy.parsing.sympy_parser.parse_expr(v, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()} + initial_values = {k: _sympy_parse_real(v, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()} # validate input for numerical issues for iv_expr in initial_values.values(): @@ -592,11 +611,11 @@ def from_ode(cls, symbol: str, definition: str, initial_values: dict, all_variab _check_numerical_issue(var) local_symbols = [symbol + Config().differential_order_symbol * i for i in range(order)] - local_symbols_sympy = [sympy.Symbol(sym_name) for sym_name in local_symbols] + local_symbols_sympy = [sympy.Symbol(sym_name, real=True) for sym_name in local_symbols] if not symbol in all_variable_symbols: all_variable_symbols.extend(local_symbols_sympy) all_variable_symbols = [str(sym_name).replace("'", Config().differential_order_symbol) for sym_name in all_variable_symbols] - all_variable_symbols_sympy = [sympy.Symbol(sym_name) for sym_name in all_variable_symbols] + all_variable_symbols_sympy = [sympy.Symbol(sym_name, real=True) for sym_name in all_variable_symbols] derivative_factors, inhom_term, nonlin_term = Shape.split_lin_inhom_nonlin(definition, all_variable_symbols_sympy, parameters=parameters) local_symbols_idx = [all_variable_symbols.index(sym) for sym in local_symbols] local_derivative_factors = [derivative_factors[i] for i in local_symbols_idx] @@ -604,7 +623,8 @@ def from_ode(cls, symbol: str, definition: str, initial_values: dict, all_variab if nonlocal_derivative_terms: nonlin_term = nonlin_term + functools.reduce(lambda x, y: x + y, nonlocal_derivative_terms) - shape = cls(sympy.Symbol(symbol), order, initial_values, local_derivative_factors, inhom_term, nonlin_term, lower_bound, upper_bound) + sym = sympy.Symbol(symbol, real=True) + shape = cls(sym, order, initial_values, local_derivative_factors, inhom_term, nonlin_term, lower_bound, upper_bound) logging.debug("\tReturning shape: " + str(shape)) return shape diff --git a/odetoolbox/singularity_detection.py b/odetoolbox/singularity_detection.py index 98013bdb..2699e8e8 100644 --- a/odetoolbox/singularity_detection.py +++ b/odetoolbox/singularity_detection.py @@ -18,13 +18,13 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . # -from typing import Dict, List, Set +from typing import Dict, List, Set, Union import logging import sympy import sympy.parsing.sympy_parser -from odetoolbox.sympy_helpers import SymmetricEq, symbol_in_expression +from .sympy_helpers import SymmetricEq, _custom_simplify_expr, symbol_in_expression, _is_zero class SingularityDetectionException(Exception): @@ -89,7 +89,7 @@ class SingularityDetection: """ @staticmethod - def _is_matrix_defined_under_substitution(A: sympy.Matrix, cond_set: Set[SymmetricEq]) -> bool: + def _is_matrix_defined_under_substitution(A: sympy.Matrix, cond: Union[SymmetricEq, Set[SymmetricEq]]) -> bool: r""" Function to check if a matrix is defined (i.e. does not contain NaN or infinity) after we perform a given set of subsitutions. @@ -101,9 +101,15 @@ def _is_matrix_defined_under_substitution(A: sympy.Matrix, cond_set: Set[Symmetr a set with equations, where the left-hand side of each equation is the variable that is to be subsituted, and the right-hand side is the expression to put in its place """ for val in sympy.flatten(A): + if isinstance(val, float) or isinstance(val, int) or isinstance(val, sympy.core.numbers.Number): + continue + expr_sub = val.copy() - for eq in cond_set: - expr_sub = expr_sub.subs(eq.lhs, eq.rhs) + if isinstance(cond, set): + for _cond in cond: + expr_sub = expr_sub.subs(_cond.lhs, _cond.rhs) + else: + expr_sub = expr_sub.subs(cond.lhs, cond.rhs) if symbol_in_expression([sympy.nan, sympy.zoo, sympy.oo], sympy.simplify(expr_sub)): return False @@ -111,56 +117,48 @@ def _is_matrix_defined_under_substitution(A: sympy.Matrix, cond_set: Set[Symmetr return True @staticmethod - def _filter_valid_conditions(conds, A: sympy.Matrix): + def _filter_valid_conditions(conds: Set[Set[SymmetricEq]], A: sympy.Matrix): filt_cond = set() - for cond_set in conds: # looping over condition sets - if SingularityDetection._is_matrix_defined_under_substitution(A, cond_set): - filt_cond.add(cond_set) + for cond in conds: # looping over condition sets + if SingularityDetection._is_matrix_defined_under_substitution(A, cond): + filt_cond.add(cond) return filt_cond @staticmethod - def _generate_singularity_conditions(A: sympy.Matrix) -> List[Dict[sympy.core.expr.Expr, sympy.core.expr.Expr]]: + def _generate_singularity_conditions(P: sympy.Matrix) -> List[Dict[sympy.core.expr.Expr, sympy.core.expr.Expr]]: r""" The function solve returns a list where each element is a dictionary. And each dictionary entry (condition: expression) corresponds to a condition at which that expression goes to zero. If the expression is quadratic, like let's say "x**2-1" then the function 'solve() returns two dictionaries in a list. Each dictionary corresponds to one solution. We are then collecting these lists in our own list called ``conditions`` and return it. """ conditions = set() - for expr in sympy.flatten(A): - for subexpr in sympy.preorder_traversal(expr): # traversing through the tree - if isinstance(subexpr, sympy.Pow) and subexpr.args[1] < 0: # find expressions of the form 1/x, which is encoded in sympy as x^-1 - denom = subexpr.args[0] # extracting the denominator - symbols = list(denom.free_symbols) - conds = SingularityDetection.find_singularity_conditions_in_expression_(denom, symbols) - conditions = conditions.union(conds) + for expr in sympy.flatten(P): + conds = SingularityDetection.find_singularity_conditions_in_expression_(expr) + conditions = conditions.union(conds) return conditions @staticmethod - def find_singularity_conditions_in_expression_(expr, symbols) -> Set[SymmetricEq]: - # find all conditions under which the denominator goes to zero. Each element of the returned list contains a particular combination of conditions for which A[row, row] goes to zero. For instance: ``solve([x - 3, y**2 - 1])`` returns ``[{x: 3, y: -1}, {x: 3, y: 1}]`` - conditions = sympy.solve(expr, symbols, dict=True, domain=sympy.S.Reals) - - # remove solutions that contain the imaginary number. ``domain=sympy.S.Reals`` does not seem to work perfectly as an argument to sympy.solve(), while sympy's ``reduce_inequalities()`` only supports univariate equations at the time of writing - accepted_conditions = [] - for cond_set in conditions: - i_in_expr = any([sympy.I in sympy.preorder_traversal(v) for v in cond_set.values()]) - if not i_in_expr: - accepted_conditions.append(cond_set) - - conditions = accepted_conditions + def find_singularity_conditions_in_expression_(expr: sympy.core.expr.Expr) -> Set[SymmetricEq]: + r"""Find conditions under which subterms of ``expr`` of the form ``a / b`` equal infinity (in general, when b = 0).""" + conditions = set() - # convert dictionaries to sympy equations - converted_conditions = set() - for cond_set in conditions: - cond_eqs_set = set([SymmetricEq(k, v) for k, v in cond_set.items()]) # convert to actual equations - converted_conditions.add(frozenset(cond_eqs_set)) + sub_expr_conds = set() + for symbol in expr.free_symbols: + assert symbol.is_real + singularity_set = sympy.calculus.singularities(expr, symbol=symbol, domain=sympy.Reals) + for singular_value in singularity_set: + singular_value = sympy.nsimplify(singular_value) # nsimplify is necessary to remove 1.0* factors + singular_condition = SymmetricEq(symbol, singular_value) + assert symbol.is_real + assert singular_value.is_real + sub_expr_conds.add(singular_condition) - conditions = converted_conditions + conditions = conditions.union(sub_expr_conds) return conditions @staticmethod - def find_inhomogeneous_singularities(expr) -> Set[SymmetricEq]: + def find_inhomogeneous_singularities(A: sympy.Matrix, b: sympy.Matrix) -> Set[SymmetricEq]: r"""Find singularities in the inhomogeneous part of the update equations. Returns @@ -169,19 +167,20 @@ def find_inhomogeneous_singularities(expr) -> Set[SymmetricEq]: conditions a set with equations, where the left-hand side of each equation is the variable that is to be subsituted, and the right-hand side is the expression to put in its place """ - logging.debug("Checking for singularities (divisions by zero) in the inhomogeneous part of the update equations...") + logging.debug("Checking for singularities due to inhomogeneous terms in the system of ODEs...") - symbols = list(expr.free_symbols) conditions = set() - if symbols: - conditions = SingularityDetection.find_singularity_conditions_in_expression_(expr, symbols) - if conditions: - # if there is one or more condition under which the solution goes to infinity... + for row in range(A.shape[0]): + if _is_zero(b[row]): + # this is not an inhomogeneous ODE + continue - logging.warning("Under certain conditions, one or more inhomogeneous term(s) in the system contain a division by zero.") - logging.warning("List of all conditions that result in a division by zero:") - for cond_set in conditions: - logging.warning("\t" + r" ∧ ".join([str(eq.lhs) + " = " + str(eq.rhs) for eq in cond_set])) + particular_solution = -b[row] / A[row, row] + particular_solution = sympy.simplify(particular_solution) # using _custom_update_expr() does not guarantee that an adequate number of cases will be covered + + conditions = conditions.union(SingularityDetection.find_singularity_conditions_in_expression_(particular_solution)) + + conditions = SingularityDetection._filter_valid_conditions(conditions, A) # filters out the invalid conditions (invalid means those for which A is not defined) return conditions @@ -203,12 +202,14 @@ def find_propagator_singularities(P: sympy.Matrix, A: sympy.Matrix) -> Set[Symme conditions a set with equations, where the left-hand side of each equation is the variable that is to be subsituted, and the right-hand side is the expression to put in its place """ - logging.debug("Checking for singularities (divisions by zero) in the propagator matrix...") - try: + logging.debug("Checking for singularities in the propagator matrix...") + # try: + if 1: conditions = SingularityDetection._generate_singularity_conditions(P) conditions = SingularityDetection._filter_valid_conditions(conditions, A) # filters out the invalid conditions (invalid means those for which A is not defined) - except Exception as e: - print(e) - raise SingularityDetectionException() + + # except Exception as e: + # print(e) + # raise SingularityDetectionException() return conditions diff --git a/odetoolbox/spike_generator.py b/odetoolbox/spike_generator.py index ca25112c..d0d83784 100644 --- a/odetoolbox/spike_generator.py +++ b/odetoolbox/spike_generator.py @@ -60,7 +60,6 @@ def spike_times_from_json(cls, stimuli, sim_time) -> Mapping[str, List[float]]: return spike_times - @classmethod def _generate_homogeneous_poisson_spikes(cls, T: float, rate: float, min_isi: float = 1E-6): r""" @@ -83,7 +82,6 @@ def _generate_homogeneous_poisson_spikes(cls, T: float, rate: float, min_isi: fl return spike_times - @classmethod def _generate_regular_spikes(cls, T: float, rate: float): r""" diff --git a/odetoolbox/stiffness.py b/odetoolbox/stiffness.py index 7de05369..764c50b4 100644 --- a/odetoolbox/stiffness.py +++ b/odetoolbox/stiffness.py @@ -24,6 +24,8 @@ import numpy.random import sympy +from odetoolbox.sympy_helpers import _sympy_parse_real + from .mixed_integrator import MixedIntegrator from .mixed_integrator import ParametersIncompleteException from .shapes import Shape @@ -68,7 +70,7 @@ def __init__(self, system_of_shapes, shapes, analytic_solver_dict=None, paramete self.parameters = {} else: self.parameters = parameters - self.parameters = {k: sympy.parsing.sympy_parser.parse_expr(v, global_dict=Shape._sympy_globals).n() for k, v in self.parameters.items()} + self.parameters = {k: _sympy_parse_real(v, global_dict=Shape._sympy_globals).n() for k, v in self.parameters.items()} self._locals = self.parameters.copy() if stimuli is None: self._stimuli = [] @@ -83,19 +85,16 @@ def __init__(self, system_of_shapes, shapes, analytic_solver_dict=None, paramete self.analytic_solver_dict["parameters"].update(self.parameters) self.analytic_integrator = None - @property def random_seed(self): return self._random_seed - @random_seed.setter def random_seed(self, value): assert type(value) is int assert value >= 0 self._random_seed = value - def check_stiffness(self, raise_errors=False): r""" Perform stiffness testing: use implicit and explicit solvers to simulate the dynamical system, then decide which is the better solver to use. @@ -118,7 +117,6 @@ def check_stiffness(self, raise_errors=False): return self._draw_decision(step_min_imp, step_min_exp, step_average_imp, step_average_exp) - def _evaluate_integrator(self, integrator, h_min_lower_bound=1E-12, raise_errors=True, debug=True): r""" This function computes the average step size and the minimal step size that a given integration method from GSL uses to evolve a certain system of ODEs during a certain simulation time, integration method from GSL and spike train for a given maximal stepsize. @@ -139,7 +137,6 @@ def _evaluate_integrator(self, integrator, h_min_lower_bound=1E-12, raise_errors spike_times = SpikeGenerator.spike_times_from_json(self._stimuli, self.sim_time) - # # initialise and run mixed integrator # @@ -164,7 +161,6 @@ def _evaluate_integrator(self, integrator, h_min_lower_bound=1E-12, raise_errors return h_min, h_avg, runtime - def _draw_decision(self, step_min_imp, step_min_exp, step_average_imp, step_average_exp, machine_precision_dist_ratio=10, avg_step_size_ratio=6): r""" Decide which is the best integrator to use. @@ -187,5 +183,5 @@ def _draw_decision(self, step_min_imp, step_min_exp, step_average_imp, step_aver if step_average_imp > avg_step_size_ratio * step_average_exp: return "implicit" - else: - return "explicit" + + return "explicit" diff --git a/odetoolbox/sympy_helpers.py b/odetoolbox/sympy_helpers.py index 9c1ac74c..dc7042e3 100644 --- a/odetoolbox/sympy_helpers.py +++ b/odetoolbox/sympy_helpers.py @@ -19,7 +19,7 @@ # along with NEST. If not, see . # -from typing import Mapping +from typing import Dict, Mapping, Optional import logging import sympy @@ -33,6 +33,37 @@ class NumericalIssueException(Exception): pass +def _sympy_parse_real(expr: str, global_dict: Optional[Dict] = None, local_dict: Optional[Dict] = None, evaluate: bool = True) -> sympy.core.expr.Expr: + r"""Custom parse function to make sure that all returned symbols have domain Real. + + Minimal global_dict to make no assumptions (e.g. "beta" could otherwise be recognised as a function instead of as a parameter symbol)""" + assert type(expr) is str + + if global_dict: + global_dict = global_dict.copy() + assert not "__builtins__" in global_dict.keys() + initial_parse = sympy.parsing.sympy_parser.parse_expr(expr, global_dict=global_dict, local_dict=local_dict, evaluate=evaluate) + + all_syms = initial_parse.free_symbols + if local_dict: + extended_local_dict = local_dict.copy() + else: + extended_local_dict = {} + + for sym in all_syms: + extended_local_dict_syms_as_str = [str(local_dict_sym) for local_dict_sym in extended_local_dict.keys()] + if not str(sym) in extended_local_dict_syms_as_str: + real_sym = sympy.Symbol(str(sym), real=True) + extended_local_dict[str(real_sym)] = real_sym + + final_parse = sympy.parsing.sympy_parser.parse_expr(expr, global_dict=global_dict, local_dict=extended_local_dict, evaluate=evaluate) + + for sym in final_parse.free_symbols: + assert sym.is_real + + return final_parse + + def _is_constant_term(term, parameters: Mapping[sympy.Symbol, str] = None) -> bool: r""" :return: :python:`True` if and only if this term contains only numerical values and parameters; :python:`False` otherwise. diff --git a/odetoolbox/system_of_shapes.py b/odetoolbox/system_of_shapes.py index 0ae0c07d..71716153 100644 --- a/odetoolbox/system_of_shapes.py +++ b/odetoolbox/system_of_shapes.py @@ -19,7 +19,8 @@ # along with NEST. If not, see . # -from typing import List, Optional +import itertools +from typing import List, Optional, Set, Union import logging import numpy as np @@ -32,7 +33,7 @@ from .config import Config from .shapes import Shape from .singularity_detection import SingularityDetection, SingularityDetectionException -from .sympy_helpers import _custom_simplify_expr, _is_zero +from .sympy_helpers import SymmetricEq, _custom_simplify_expr, _is_zero, _sympy_parse_real class GetBlockDiagonalException(Exception): @@ -96,21 +97,20 @@ def __init__(self, x: sympy.Matrix, A: sympy.Matrix, b: sympy.Matrix, c: sympy.M self.c_ = c self.shapes_ = shapes - - def get_shape_by_symbol(self, sym: str) -> Optional[Shape]: + def get_shape_by_symbol(self, sym: Union[str, sympy.Symbol]) -> Optional[Shape]: for shape in self.shapes_: - if str(shape.symbol) == sym: + if str(shape.symbol) == str(sym): return shape + return None - def get_initial_value(self, sym): + def get_initial_value(self, sym: Union[str, sympy.Symbol]): for shape in self.shapes_: if str(shape.symbol) == str(sym).replace(Config().differential_order_symbol, "").replace("'", ""): - return shape.get_initial_value(sym.replace(Config().differential_order_symbol, "'")) + return shape.get_initial_value(str(sym).replace(Config().differential_order_symbol, "'")) assert False, "Unknown symbol: " + str(sym) - def get_dependency_edges(self): E = [] @@ -121,7 +121,6 @@ def get_dependency_edges(self): return E - def get_lin_cc_symbols(self, E, parameters=None): r""" Retrieve the variable symbols of those shapes that are linear and constant coefficient. In the case of a higher-order shape, will return all the variable symbols with ``"__d"`` suffixes up to the order of the shape. @@ -141,7 +140,6 @@ def get_lin_cc_symbols(self, E, parameters=None): return node_is_lin - def propagate_lin_cc_judgements(self, node_is_lin, E): r""" Propagate: if a node depends on a node that is not linear and constant coefficient, it cannot be linear and constant coefficient. @@ -164,7 +162,6 @@ def propagate_lin_cc_judgements(self, node_is_lin, E): return node_is_lin - def get_jacobian_matrix(self): r""" Get the Jacobian matrix as symbolic expressions. Entries in the matrix are sympy expressions. @@ -181,7 +178,6 @@ def get_jacobian_matrix(self): J[i, j] = sympy.diff(expr, sym2) return J - def get_sub_system(self, symbols): r""" Return a new :python:`SystemOfShapes` instance which discards all symbols and equations except for those in :python:`symbols`. This is probably only sensible when the elements in :python:`symbols` do not dependend on any of the other symbols that will be thrown away. @@ -204,21 +200,19 @@ def get_sub_system(self, symbols): return SystemOfShapes(x_sub, A_sub, b_sub, c_sub, shapes_sub) - - def _generate_propagator_matrix(self, A): + def _generate_propagator_matrix(self, A) -> sympy.Matrix: r"""Generate the propagator matrix by matrix exponentiation.""" - # naive: calculate propagators in one step - # P_naive = _custom_simplify_expr(sympy.exp(A * sympy.Symbol(Config().output_timestep_symbol))) - - # optimized: be explicit about block diagonal elements; much faster! try: + # optimized: be explicit about block diagonal elements; much faster! + logging.debug("Computing propagator matrix (block-diagonal optimisation)...") blocks = get_block_diagonal_blocks(np.array(A)) - propagators = [sympy.simplify(sympy.exp(sympy.Matrix(block) * sympy.Symbol(Config().output_timestep_symbol))) for block in blocks] + propagators = [sympy.simplify(sympy.exp(sympy.Matrix(block) * sympy.Symbol(Config().output_timestep_symbol, real=True))) for block in blocks] P = sympy.Matrix(scipy.linalg.block_diag(*propagators)) except GetBlockDiagonalException: # naive: calculate propagators in one step - P = sympy.simplify(sympy.exp(A * sympy.Symbol(Config().output_timestep_symbol))) + logging.debug("Computing propagator matrix...") + P = _custom_simplify_expr(sympy.exp(A * sympy.Symbol(Config().output_timestep_symbol, real=True))) # check the result if sympy.I in sympy.preorder_traversal(P): @@ -226,6 +220,34 @@ def _generate_propagator_matrix(self, A): return P + def _merge_conditions(self, solver_dict): + r"""merge together conditions (a OR b OR c OR...) if the propagators and update_expressions are the same""" + + for condition, sub_solver_dict in solver_dict["conditions"].items(): + for condition2, sub_solver_dict2 in solver_dict["conditions"].items(): + if condition == condition2: + # don't check a condition against itself + continue + + if sub_solver_dict["propagators"] == sub_solver_dict2["propagators"] and sub_solver_dict["update_expressions"] == sub_solver_dict2["update_expressions"]: + # ``condition`` and ``condition2`` can be merged + solver_dict["conditions"]["(" + condition + ") || (" + condition2 + ")"] = sub_solver_dict + solver_dict["conditions"].pop(condition) + solver_dict["conditions"].pop(condition2) + return self._merge_conditions(solver_dict) + + return solver_dict + + def _remove_duplicate_conditions(self, conditions: Set[SymmetricEq]): + for cond in conditions: + inverted_eq = SymmetricEq(-cond.lhs, -cond.rhs) + if inverted_eq in conditions: + conditions.discard(cond) + return self._remove_duplicate_conditions(conditions) + + # nothing was removed + return conditions + def generate_propagator_solver(self, disable_singularity_detection: bool = False): r""" Generate the propagator matrix and symbolic expressions for propagator-based updates; return as JSON. @@ -240,25 +262,88 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False if not disable_singularity_detection: try: conditions = SingularityDetection.find_propagator_singularities(P, self.A_) + conditions = conditions.union(SingularityDetection.find_inhomogeneous_singularities(self.A_, self.b_)) + conditions = self._remove_duplicate_conditions(conditions) if conditions: # if there is one or more condition under which the solution goes to infinity... + logging.info("Under certain conditions, the default analytic solver contains singularities due to division by zero.") + logging.info("List of all conditions that result in a division by zero:") + for cond in conditions: + logging.info("\t" + str(cond.lhs) + " = " + str(cond.rhs)) + logging.info("Alternate solvers will be generated for each of these conditions (and combinations thereof).") + + # generate solver for the base case (with singularity conditions that are not met) + default_solver = self.generate_solver_dict_based_on_propagator_matrix_(P) + + # change the returned solver dictionary to include conditions + solver_dict = {"solver": "analytical", + "state_variables": default_solver["state_variables"], + "initial_values": default_solver["initial_values"], + "conditions": {"default": {"propagators": default_solver["propagators"], + "update_expressions": default_solver["update_expressions"]}}} + + # + # generate all combinations of conditions + # + + num_conditions = len(conditions) + condition_permutations = list(itertools.product([False, True], repeat=num_conditions)) + for condition_permutation in condition_permutations: + # each ``condition_permutation[i]`` is True/False corresponding to condition i + + cond_set = set() # cond_set is the set of conditions that have to hold + for i, cond_holds in enumerate(condition_permutation): + cond = list(conditions)[i] + if cond_holds: + # ``cond`` needs to hold for this propagator + cond_set.add(cond) + else: + # ``cond`` needs to **not** hold for this propagator + cond_set.add(sympy.Ne(cond.lhs, cond.rhs)) + + condition_str: str = " && ".join(["(" + str(eq.lhs) + (" == " if isinstance(eq, SymmetricEq) else "!=") + str(eq.rhs) + ")" for eq in cond_set]) + + logging.debug("Generating solver for condition: " + str(condition_str)) + + if not any([isinstance(eq, SymmetricEq) for eq in cond_set]): + # this is the default condition, only containing inequalities + continue + + conditional_A = self.A_.copy() + conditional_b = self.b_.copy() + conditional_c = self.c_.copy() + + for eq in cond_set: + if isinstance(eq, SymmetricEq): + # replace equalities (not inequalities) + conditional_A = conditional_A.subs(eq.lhs, eq.rhs) + conditional_b = conditional_b.subs(eq.lhs, eq.rhs) + conditional_c = conditional_c.subs(eq.lhs, eq.rhs) + + conditional_dynamics = SystemOfShapes(self.x_, conditional_A, conditional_b, conditional_c, self.shapes_) + solver_dict_conditional = conditional_dynamics.generate_propagator_solver(disable_singularity_detection=True) + solver_dict["conditions"][condition_str] = {"propagators": solver_dict_conditional["propagators"], + "update_expressions": solver_dict_conditional["update_expressions"]} + + solver_dict = self._merge_conditions(solver_dict) + + return solver_dict - logging.warning("Under certain conditions, the propagator matrix is singular (contains infinities).") - logging.warning("List of all conditions that result in a division by zero:") - for cond_set in conditions: - logging.warning("\t" + r" ∧ ".join([str(eq.lhs) + " = " + str(eq.rhs) for eq in cond_set])) except SingularityDetectionException: logging.warning("Could not check the propagator matrix for singularities.") + return self.generate_solver_dict_based_on_propagator_matrix_(P) + + def generate_solver_dict_based_on_propagator_matrix_(self, P: sympy.Matrix): + # # generate symbols for each nonzero entry of the propagator matrix # - P_sym = sympy.zeros(*P.shape) # each entry in the propagator matrix is assigned its own symbol P_expr = {} # the expression corresponding to each propagator symbol update_expr = {} # keys are str(variable symbol), values are str(expressions) that evaluate to the new value of the corresponding key - for row in range(P_sym.shape[0]): + for row in range(P.shape[0]): # assemble update expression for symbol ``self.x_[row]`` if not _is_zero(self.c_[row]): raise PropagatorGenerationException("For symbol " + str(self.x_[row]) + ": nonlinear part should be zero for propagators") @@ -267,10 +352,9 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False raise PropagatorGenerationException("For symbol " + str(self.x_[row]) + ": higher-order inhomogeneous ODEs are not supported") update_expr_terms = [] - for col in range(P_sym.shape[1]): + for col in range(P.shape[1]): if not _is_zero(P[row, col]): sym_str = Config().propagators_prefix + "__{}__{}".format(str(self.x_[row]), str(self.x_[col])) - P_sym[row, col] = sympy.parsing.sympy_parser.parse_expr(sym_str, global_dict=Shape._sympy_globals) P_expr[sym_str] = P[row, col] if row != col and not _is_zero(self.b_[col]): # the ODE for x_[row] depends on the inhomogeneous ODE of x_[col]. We can't solve this analytically in the general case (even though some specific cases might admit a solution) @@ -285,14 +369,6 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False update_expr_terms.append(Config().output_timestep_symbol + " * " + str(self.b_[row])) else: - # - # singularity detection on inhomogeneous part - # - - if not disable_singularity_detection: - SingularityDetection.find_inhomogeneous_singularities(expr=self.A_[row, row]) - - # # generate update expressions # @@ -303,7 +379,7 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False update_expr_terms.append(sym_str + " * (" + str(self.x_[row]) + " - (" + str(particular_solution) + "))" + " + (" + str(particular_solution) + ")") update_expr[str(self.x_[row])] = " + ".join(update_expr_terms) - update_expr[str(self.x_[row])] = sympy.parsing.sympy_parser.parse_expr(update_expr[str(self.x_[row])], global_dict=Shape._sympy_globals) + update_expr[str(self.x_[row])] = _sympy_parse_real(update_expr[str(self.x_[row])], global_dict=Shape._sympy_globals) if not _is_zero(self.b_[row]): # only simplify in case an inhomogeneous term is present update_expr[str(self.x_[row])] = _custom_simplify_expr(update_expr[str(self.x_[row])]) @@ -319,7 +395,6 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False return solver_dict - def generate_numeric_solver(self, state_variables=None): r""" Generate the symbolic expressions for numeric integration state updates; return as JSON. @@ -335,7 +410,6 @@ def generate_numeric_solver(self, state_variables=None): return solver_dict - def reconstitute_expr(self, state_variables=None): r""" Reconstitute a sympy expression from a system of shapes (which is internally encoded in the form :math:`\mathbf{x}' = \mathbf{Ax} + \mathbf{b} + \mathbf{c}`). @@ -355,7 +429,7 @@ def reconstitute_expr(self, state_variables=None): else: update_expr_terms.append(str(y) + " * (" + str(self.A_[row, col]) + ")") update_expr[str(x)] = " + ".join(update_expr_terms) + " + (" + str(self.b_[row]) + ") + (" + str(self.c_[row]) + ")" - update_expr[str(x)] = sympy.parsing.sympy_parser.parse_expr(update_expr[str(x)], global_dict=Shape._sympy_globals) + update_expr[str(x)] = _sympy_parse_real(update_expr[str(x)], global_dict=Shape._sympy_globals) # custom expression simplification for name, expr in update_expr.items(): @@ -365,7 +439,6 @@ def reconstitute_expr(self, state_variables=None): return update_expr - def shape_order_from_system_matrix(self, idx: int) -> int: r"""Determine shape differential order from system matrix of symbol ``self.x_[idx]``""" N = self.A_.shape[0] @@ -378,7 +451,6 @@ def shape_order_from_system_matrix(self, idx: int) -> int: shape_order = sum(scc == scc[idx]) return shape_order - def get_connected_symbols(self, idx: int) -> List[sympy.Symbol]: r"""Extract all symbols belonging to a shape with symbol ``self.x_[idx]`` from the system matrix. @@ -401,7 +473,6 @@ def get_connected_symbols(self, idx: int) -> List[sympy.Symbol]: idx = np.where(scc == scc[idx])[0] return [self.x_[i] for i in idx] - @classmethod def from_shapes(cls, shapes: List[Shape], parameters=None): r""" @@ -430,7 +501,7 @@ def from_shapes(cls, shapes: List[Shape], parameters=None): i = 0 for shape in shapes: - highest_diff_sym_idx = [k for k, el in enumerate(x) if el == sympy.Symbol(str(shape.symbol) + Config().differential_order_symbol * (shape.order - 1))][0] + highest_diff_sym_idx = [k for k, el in enumerate(x) if el == sympy.Symbol(str(shape.symbol) + Config().differential_order_symbol * (shape.order - 1), real=True)][0] shape_expr = shape.reconstitute_expr() # @@ -442,7 +513,6 @@ def from_shapes(cls, shapes: List[Shape], parameters=None): b[highest_diff_sym_idx] = inhom_term c[highest_diff_sym_idx] = nonlin_term - # # for higher-order shapes: mark derivatives x_i' = x_(i+1) for i < shape.order # diff --git a/requirements-testing.txt b/requirements-testing.txt new file mode 100644 index 00000000..392e123c --- /dev/null +++ b/requirements-testing.txt @@ -0,0 +1,2 @@ +pytest +semver diff --git a/requirements.txt b/requirements.txt index 226b48fa..8e042fb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ sympy scipy numpy>=1.8.2 -pytest cython diff --git a/tests/test_analytic_integrator.py b/tests/test_analytic_integrator.py index 806758a6..0fef1c7c 100644 --- a/tests/test_analytic_integrator.py +++ b/tests/test_analytic_integrator.py @@ -19,6 +19,7 @@ # along with NEST. If not, see . # +import logging import sympy import numpy as np @@ -47,16 +48,16 @@ def test_analytic_integrator_alpha_function_of_time(self): h = 1E-3 # [s] T = 100E-3 # [s] - # # timeseries using ode-toolbox generated propagators # indict = _open_json("test_alpha_function_of_time.json") - solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True) + solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True, log_level=logging.DEBUG) assert len(solver_dict) == 1 solver_dict = solver_dict[0] assert solver_dict["solver"] == "analytical" + assert len(solver_dict["state_variables"]) == 2 ODE_INITIAL_VALUES = {"I": 0., "I__d": 0.} @@ -82,7 +83,7 @@ def test_analytic_integrator_alpha_function_of_time(self): state_ = analytic_integrator.get_value(t) state[use_caching]["timevec"].append(t) for sym, val in state_.items(): - state[use_caching][sym].append(val) + state[use_caching][str(sym)].append(val) for use_caching in [False, True]: for k, v in state[use_caching].items(): @@ -110,4 +111,4 @@ def test_analytic_integrator_alpha_function_of_time(self): np.testing.assert_allclose(state[True]["timevec"], timevec) np.testing.assert_allclose(state[True]["timevec"], state[False]["timevec"]) for sym, val in state_.items(): - np.testing.assert_allclose(state[True][sym], state[False][sym]) + np.testing.assert_allclose(state[True][str(sym)], state[False][str(sym)]) diff --git a/tests/test_analytic_solver_integration.py b/tests/test_analytic_solver_integration.py index 780698b6..52d526b9 100644 --- a/tests/test_analytic_solver_integration.py +++ b/tests/test_analytic_solver_integration.py @@ -22,6 +22,8 @@ import math import numpy as np import os +import pytest +import semver import sympy import sympy.parsing.sympy_parser import scipy @@ -44,6 +46,11 @@ from tests.test_utils import _open_json +sympy_version = semver.Version.parse(sympy.__version__) +SYMPY_VERSION_TOO_OLD = (sympy_version.major < 1) or (sympy_version.major == 1 and sympy_version.minor < 12) + + +@pytest.mark.skipif(SYMPY_VERSION_TOO_OLD, reason="Older versions of sympy hang on this test") class TestAnalyticSolverIntegration: r""" Numerical comparison between ode-toolbox calculated propagators, hand-calculated propagators expressed in Python, and numerical integration, for the iaf_cond_alpha neuron. @@ -162,7 +169,6 @@ def f(t, y): i_ex = numerical_sol[:2, :] v_rel = numerical_sol[2, :] - # # timeseries using hand-calculated propagators (only for alpha postsynaptic currents, not V_rel) # @@ -183,7 +189,6 @@ def f(t, y): i_ex__[:, step - 1] = i_ex_init i_ex__[:, step] = np.dot(P, i_ex__[:, step - 1]) - # # timeseries using ode-toolbox generated propagators # @@ -219,7 +224,7 @@ def f(t, y): state_ = analytic_integrator.get_value(t) state["timevec"].append(t) for sym, val in state_.items(): - state[sym].append(val) + state[str(sym)].append(val) for k, v in state.items(): state[k] = np.array(v) diff --git a/tests/test_double_exponential.py b/tests/test_double_exponential.py index 025d5839..4055afb1 100644 --- a/tests/test_double_exponential.py +++ b/tests/test_double_exponential.py @@ -20,7 +20,9 @@ # import numpy as np +import pytest from scipy.integrate import odeint +import sympy import odetoolbox @@ -29,7 +31,7 @@ try: import matplotlib as mpl - mpl.use('Agg') + mpl.use("Agg") import matplotlib.pyplot as plt INTEGRATION_TEST_DEBUG_PLOTS = True except ImportError: @@ -39,23 +41,32 @@ class TestDoubleExponential: r"""Test propagators generation for double exponential""" - def test_double_exponential(self): - r"""Test propagators generation for double exponential""" + @pytest.mark.parametrize("tau_1, tau_2", [(10., 2.), (10., 10.)]) + def test_double_exponential(self, tau_1, tau_2): + r"""Test propagators generation for double exponential + + Test for a case where tau_1 != tau_2 and where tau_1 == tau_2; this tests handling of numerical singularities. + + tau_1: decay time constant (ms) + tau_2: rise time constant (ms) + """ def time_to_max(tau_1, tau_2): r""" Time of maximum. """ - tmax = (np.log(tau_1) - np.log(tau_2)) / (1. / tau_2 - 1. / tau_1) - return tmax + if tau_1 == tau_2: + return tau_1 + + return (np.log(tau_1) - np.log(tau_2)) / (1. / tau_2 - 1. / tau_1) def unit_amplitude(tau_1, tau_2): r""" Scaling factor ensuring that amplitude of solution is one. """ tmax = time_to_max(tau_1, tau_2) - alpha = 1. / (np.exp(-tmax / tau_1) - np.exp(-tmax / tau_2)) - return alpha + + return 1. / (np.exp(-tmax / tau_1) - np.exp(-tmax / tau_2)) def flow(y, t, tau_1, tau_2, alpha, dt): r""" @@ -66,26 +77,26 @@ def flow(y, t, tau_1, tau_2, alpha, dt): return np.array([dy1dt, dy2dt]) + if tau_1 == tau_2: + alpha = 1. + else: + alpha = unit_amplitude(tau_1=tau_1, tau_2=tau_2) + indict = {"dynamics": [{"expression": "I_aux' = -I_aux / tau_1", "initial_values": {"I_aux": "0."}}, {"expression": "I' = I_aux - I / tau_2", "initial_values": {"I": "0"}}], "options": {"output_timestep_symbol": "__h"}, - "parameters": {"tau_1": "10", - "tau_2": "2", + "parameters": {"tau_1": str(tau_1), + "tau_2": str(tau_2), "w": "3.14", - "alpha": str(unit_amplitude(tau_1=10., tau_2=2.)), - "weighted_input_spikes": "0."}} + "alpha": str(alpha)}} w = 3.14 # weight (amplitude; pA) - tau_1 = 10. # decay time constant (ms) - tau_2 = 2. # rise time constant (ms) dt = .125 # time resolution (ms) T = 500. # simulation time (ms) input_spike_times = np.array([100., 300.]) # array of input spike times (ms) - alpha = unit_amplitude(tau_1, tau_2) - stimuli = [{"type": "list", "list": " ".join([str(el) for el in input_spike_times]), "variables": ["I_aux"]}] @@ -103,7 +114,7 @@ def flow(y, t, tau_1, tau_2, alpha, dt): N = int(np.ceil(T / dt) + 1) timevec = np.linspace(0., T, N) analytic_integrator = AnalyticIntegrator(solver_dict, spike_times) - analytic_integrator.shape_starting_values["I_aux"] = w * alpha * (1. / tau_2 - 1. / tau_1) + analytic_integrator.shape_starting_values[sympy.Symbol("I_aux", real=True)] = w * alpha analytic_integrator.set_initial_values(ODE_INITIAL_VALUES) analytic_integrator.reset() state = {"timevec": [], "I": [], "I_aux": []} @@ -111,7 +122,7 @@ def flow(y, t, tau_1, tau_2, alpha, dt): state_ = analytic_integrator.get_value(t) state["timevec"].append(t) for sym, val in state_.items(): - state[sym].append(val) + state[str(sym)].append(val) # solve with odeint ts0 = np.arange(0., input_spike_times[0] - dt / 2, dt) @@ -119,24 +130,24 @@ def flow(y, t, tau_1, tau_2, alpha, dt): ts2 = np.arange(input_spike_times[1], T + dt, dt) y_ = odeint(flow, [0., 0.], ts0, args=(tau_1, tau_2, alpha, dt)) - y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha * (1. / tau_2 - 1. / tau_1), y_[-1, 1]], ts1, args=(tau_1, tau_2, alpha, dt))]) - y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha * (1. / tau_2 - 1. / tau_1), y_[-1, 1]], ts2, args=(tau_1, tau_2, alpha, dt))]) + y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha, y_[-1, 1]], ts1, args=(tau_1, tau_2, alpha, dt))]) + y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha, y_[-1, 1]], ts2, args=(tau_1, tau_2, alpha, dt))]) - rec_I_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state['I']) - rec_I_aux_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state['I_aux']) + rec_I_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state["I"]) + rec_I_aux_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state["I_aux"]) if INTEGRATION_TEST_DEBUG_PLOTS: tmax = time_to_max(tau_1, tau_2) - mpl.rcParams['text.usetex'] = True + mpl.rcParams["text.usetex"] = True fig, ax = plt.subplots(nrows=2, figsize=(5, 4), dpi=300) - ax[0].plot(timevec, state['I_aux'], '--', lw=3, color='k', label=r'$I_\mathsf{aux}(t)$ (NEST)') - ax[0].plot(timevec, state['I'], '-', lw=3, color='k', label=r'$I(t)$ (NEST)') - ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 0], '--', lw=2, color='r', label=r'$I_\mathsf{aux}(t)$ (odeint)') - ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 1], '-', lw=2, color='r', label=r'$I(t)$ (odeint)') + ax[0].plot(timevec, state["I_aux"], "--", lw=3, color="k", label=r"$I_\mathsf{aux}(t)$ (ODEtb)") + ax[0].plot(timevec, state["I"], "-", lw=3, color="k", label=r"$I(t)$ (ODEtb)") + ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 0], "--", lw=2, color="r", label=r"$I_\mathsf{aux}(t)$ (odeint)") + ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 1], "-", lw=2, color="r", label=r"$I(t)$ (odeint)") for tin in input_spike_times: - ax[0].vlines(tin + tmax, ax[0].get_ylim()[0], ax[0].get_ylim()[1], colors='k', linestyles=':') + ax[0].vlines(tin + tmax, ax[0].get_ylim()[0], ax[0].get_ylim()[1], colors="k", linestyles=":") ax[1].semilogy(np.hstack([ts0, ts1, ts2]), np.abs(y_[:, 1] - rec_I_interp), label="I") ax[1].semilogy(np.hstack([ts0, ts1, ts2]), np.abs(y_[:, 0] - rec_I_aux_interp), linestyle="--", label="I_aux") @@ -146,9 +157,9 @@ def flow(y, t, tau_1, tau_2, alpha, dt): _ax.set_xlim(0., T + dt) _ax.legend() - ax[-1].set_xlabel(r'time (ms)') + ax[-1].set_xlabel(r"time (ms)") - fig.savefig('double_exp_test.png') + fig.savefig("double_exp_test_[tau_1=" + str(tau_1) + "]_[tau_2=" + str(tau_2) + "].png") np.testing.assert_allclose(y_[:, 1], rec_I_interp, atol=1E-7) diff --git a/tests/test_inhomogeneous.py b/tests/test_inhomogeneous.py index a80c15cc..b2f99a9b 100644 --- a/tests/test_inhomogeneous.py +++ b/tests/test_inhomogeneous.py @@ -51,7 +51,7 @@ def test_constant_rate(self, dt: float): cur_x = x0 timevec = np.arange(0., 100., dt) for step, t in enumerate(timevec): - state_ = analytic_integrator.get_value(t)["x"] + state_ = analytic_integrator.get_value(t)[sympy.Symbol("x", real=True)] actual.append(state_) cur_x = x0 + 42 * t @@ -70,8 +70,8 @@ def test_inhomogeneous_solver(self, dt, ode_definition): x0 = 0. - parameters_dict = {sympy.Symbol("U"): str(U), - sympy.Symbol("tau"): str(tau)} + parameters_dict = {sympy.Symbol("U", real=True): str(U), + sympy.Symbol("tau", real=True): str(tau)} shape = Shape.from_ode("x", ode_definition, initial_values={"x": str(x0)}, parameters=parameters_dict) @@ -80,7 +80,7 @@ def test_inhomogeneous_solver(self, dt, ode_definition): sys_of_shape = SystemOfShapes.from_shapes([shape], parameters=parameters_dict) solver_dict = sys_of_shape.generate_propagator_solver() - solver_dict["parameters"] = parameters_dict + solver_dict["parameters"] = {str(sym): expr for sym, expr in parameters_dict.items()} analytic_integrator = AnalyticIntegrator(solver_dict) analytic_integrator.set_initial_values({"x": str(x0)}) @@ -92,7 +92,7 @@ def test_inhomogeneous_solver(self, dt, ode_definition): timevec = np.arange(0., 100., dt) kernel = np.exp(-dt / tau) for step, t in enumerate(timevec): - state_ = analytic_integrator.get_value(t)["x"] + state_ = analytic_integrator.get_value(t)[sympy.Symbol("x", real=True)] actual.append(state_) correct.append(cur_x) cur_x = U + kernel * (cur_x - U) @@ -106,16 +106,16 @@ def test_inhomogeneous_simultaneous(self, dt: float): x0 = 0. - parameters_dict = {sympy.Symbol("U"): str(U), - sympy.Symbol("tau1"): str(tau), - sympy.Symbol("tau2"): str(tau)} + parameters_dict = {sympy.Symbol("U", real=True): str(U), + sympy.Symbol("tau1", real=True): str(tau), + sympy.Symbol("tau2", real=True): str(tau)} shape_x = Shape.from_ode("x", "(U - x) / tau1", initial_values={"x": str(x0)}, parameters=parameters_dict) shape_y = Shape.from_ode("y", "(1 - y) / tau2", initial_values={"y": str(x0)}, parameters=parameters_dict) sys_of_shape = SystemOfShapes.from_shapes([shape_x, shape_y], parameters=parameters_dict) solver_dict = sys_of_shape.generate_propagator_solver() - solver_dict["parameters"] = parameters_dict + solver_dict["parameters"] = {str(sym): expr for sym, expr in parameters_dict.items()} analytic_integrator = AnalyticIntegrator(solver_dict) analytic_integrator.set_initial_values({"x": str(x0), "y": str(x0)}) @@ -130,8 +130,8 @@ def test_inhomogeneous_simultaneous(self, dt: float): timevec = np.arange(0., 100., dt) kernel = np.exp(-dt / tau) for step, t in enumerate(timevec): - state_x = analytic_integrator.get_value(t)["x"] - state_y = analytic_integrator.get_value(t)["y"] + state_x = analytic_integrator.get_value(t)[sympy.Symbol("x", real=True)] + state_y = analytic_integrator.get_value(t)[sympy.Symbol("y", real=True)] actual_x.append(state_x) actual_y.append(state_y) correct_x.append(cur_x) @@ -146,7 +146,7 @@ def test_inhomogeneous_simultaneous(self, dt: float): def test_inhomogeneous_solver_second_order(self): r"""test failure to generate propagators for inhomogeneous 2nd order ODE""" tau = 10. # [s] - parameters_dict = {sympy.Symbol("tau"): str(tau)} + parameters_dict = {sympy.Symbol("tau", real=True): str(tau)} x0 = 0. x0d = 10. @@ -159,7 +159,7 @@ def test_inhomogeneous_solver_second_order(self): def test_inhomogeneous_solver_second_order_system(self): r"""test failure to generate propagators for inhomogeneous 2nd order ODE""" tau = 10. # [s] - parameters_dict = {sympy.Symbol("tau"): str(tau)} + parameters_dict = {sympy.Symbol("tau", real=True): str(tau)} x0 = 0. x0d = 10. @@ -185,8 +185,8 @@ def test_inhomogeneous_solver_second_order_combined_system(self): r"""test propagators generation for combined homogeneous/inhomogeneous ODEs""" tau = 10. # [s] E_L = -70. # [mV] - parameters_dict = {sympy.Symbol("tau"): str(tau), - sympy.Symbol("E_L"): str(E_L)} + parameters_dict = {sympy.Symbol("tau", real=True): str(tau), + sympy.Symbol("E_L", real=True): str(E_L)} x0 = 0. x0d = 10. diff --git a/tests/test_inhomogeneous_numerically_zero.py b/tests/test_inhomogeneous_numerically_zero.py index dc447f63..d381b7e0 100644 --- a/tests/test_inhomogeneous_numerically_zero.py +++ b/tests/test_inhomogeneous_numerically_zero.py @@ -23,6 +23,7 @@ import numpy as np import scipy.integrate +import sympy try: import matplotlib as mpl @@ -48,6 +49,7 @@ def _test_inhomogeneous_numerically_zero(self, late_ltd_check, late_ltp_check): assert len(solver_dict) == 1 solver_dict = solver_dict[0] assert solver_dict["solver"].startswith("analytic") + assert len(solver_dict["conditions"].keys()) == 2, "There should be exactly two conditions: default, and late_ltp_check == -late_ltd_check" print(solver_dict) solver_dict["parameters"] = {} @@ -56,25 +58,24 @@ def _test_inhomogeneous_numerically_zero(self, late_ltd_check, late_ltp_check): solver_dict["parameters"]["late_ltd_check"] = late_ltd_check solver_dict["parameters"]["late_ltp_check"] = late_ltp_check + z0 = 0.0 # set the initial condition dt = .1 T = 100. - timevec = np.arange(0., T, dt) # # integration using the ODE-toolbox analytic integrator # + timevec = np.arange(0., T, dt) analytic_integrator = AnalyticIntegrator(solver_dict) - analytic_integrator.set_initial_values({"z": 0.}) + analytic_integrator.set_initial_values({"z": z0}) analytic_integrator.reset() - actual = [analytic_integrator.get_value(t)["z"] for t in timevec] - + actual = [analytic_integrator.get_value(t)[sympy.Symbol("z", real=True)] for t in timevec] # # integration using scipy.integrate.odeint # - def ode_model(z, t, p, late_ltp_check, late_ltd_check, tau_z): """ Defines the differential equation for z. @@ -83,7 +84,6 @@ def ode_model(z, t, p, late_ltp_check, late_ltd_check, tau_z): dzdt = (((p * (1.0 - z) * late_ltp_check) - (p * (z + 0.5) * late_ltd_check))) / tau_z return dzdt - z0 = 0.0 # set the initial condition params = solver_dict["parameters"] ode_args = ( params["p"], @@ -95,7 +95,6 @@ def ode_model(z, t, p, late_ltp_check, late_ltd_check, tau_z): solution = scipy.integrate.odeint(ode_model, z0, timevec, args=ode_args, rtol=1E-12, atol=1E-12) correct = solution.flatten().tolist() - # # plot # @@ -115,18 +114,15 @@ def ode_model(z, t, p, late_ltp_check, late_ltd_check, tau_z): fig.savefig("/tmp/test_propagators_[late_ltd_check=" + str(late_ltd_check) + "]_[late_ltp_check=" + str(late_ltp_check) + "].png") - # # test # np.testing.assert_allclose(correct, actual) - @pytest.mark.xfail(strict=True, raises=AssertionError) def test_inhomogeneous_numerically_zero(self): self._test_inhomogeneous_numerically_zero(late_ltd_check=1., late_ltp_check=-1.) - @pytest.mark.xfail(strict=True, raises=AssertionError) def test_inhomogeneous_numerically_zero_alt(self): self._test_inhomogeneous_numerically_zero(late_ltd_check=0., late_ltp_check=0.) diff --git a/tests/test_lin_const_coeff_and_homogeneous.py b/tests/test_lin_const_coeff_and_homogeneous.py index 02ba84d7..37cf8ec4 100644 --- a/tests/test_lin_const_coeff_and_homogeneous.py +++ b/tests/test_lin_const_coeff_and_homogeneous.py @@ -28,46 +28,42 @@ class TestLinConstCoeffAndHomogeneous: """Test homogeneous and linear-and-constant-coefficient judgements on individual ODEs""" - _parameters = {sympy.Symbol("a"): "1", - sympy.Symbol("b"): "3.14159"} + _parameters = {sympy.Symbol("a", real=True): "1", + sympy.Symbol("b", real=True): "3.14159"} def test_from_function(self): shape = Shape.from_function("I_in", "(e/tau_syn_in) * t * exp(-t/tau_syn_in)") assert shape.is_homogeneous() assert shape.is_lin_const_coeff() - assert shape.is_lin_const_coeff_in([sympy.Symbol("I_in"), sympy.Symbol("I_in__d")], parameters={sympy.Symbol("tau_syn_in"): "3.14159"}) - + assert shape.is_lin_const_coeff_in([sympy.Symbol("I_in", real=True), sympy.Symbol("I_in__d", real=True)], parameters={sympy.Symbol("tau_syn_in", real=True): "3.14159"}) def test_nonlinear_inhomogeneous(self): shape = Shape.from_ode("q", "(a - q**2) / b", initial_values={"q": "0."}, parameters=TestLinConstCoeffAndHomogeneous._parameters) assert not shape.is_homogeneous() assert not shape.is_lin_const_coeff() - assert not shape.is_lin_const_coeff_in([sympy.Symbol("q")], parameters=TestLinConstCoeffAndHomogeneous._parameters) - + assert not shape.is_lin_const_coeff_in([sympy.Symbol("q", real=True)], parameters=TestLinConstCoeffAndHomogeneous._parameters) def test_nonlinear_homogeneous(self): shape = Shape.from_ode("q", "-q**2 / b", initial_values={"q": "0."}, parameters=TestLinConstCoeffAndHomogeneous._parameters) assert shape.is_homogeneous() assert not shape.is_lin_const_coeff() - assert not shape.is_lin_const_coeff_in([sympy.Symbol("q")], parameters=TestLinConstCoeffAndHomogeneous._parameters) - + assert not shape.is_lin_const_coeff_in([sympy.Symbol("q", real=True)], parameters=TestLinConstCoeffAndHomogeneous._parameters) def test_from_homogeneous_ode(self): shape = Shape.from_ode("q", "-q / b", initial_values={"q": "0."}) assert shape.is_homogeneous() assert not shape.is_lin_const_coeff() - assert shape.is_lin_const_coeff_in([sympy.Symbol("q")], parameters=TestLinConstCoeffAndHomogeneous._parameters) - + assert shape.is_lin_const_coeff_in([sympy.Symbol("q", real=True)], parameters=TestLinConstCoeffAndHomogeneous._parameters) def test_from_homogeneous_ode_alternate(self): shape = Shape.from_ode("q", "(a - q) / b", initial_values={"q": "0."}, parameters=TestLinConstCoeffAndHomogeneous._parameters) assert not shape.is_homogeneous() assert shape.is_lin_const_coeff() - assert shape.is_lin_const_coeff_in([sympy.Symbol("q")], parameters=TestLinConstCoeffAndHomogeneous._parameters) + assert shape.is_lin_const_coeff_in([sympy.Symbol("q", real=True)], parameters=TestLinConstCoeffAndHomogeneous._parameters) # xfail case: forgot to specify parameters shape = Shape.from_ode("q", "(a - q) / b", initial_values={"q": "0."}) @@ -78,14 +74,14 @@ def test_from_homogeneous_ode_alternate(self): class TestLinConstCoeffAndHomogeneousSystem: """Test homogeneous and linear-and-constant-coefficient judgements on systems of ODEs""" - _parameters = {sympy.Symbol("I_e"): "1.", - sympy.Symbol("Tau"): "1.", - sympy.Symbol("C_m"): "1.", - sympy.Symbol("tau_syn_in"): "1.", - sympy.Symbol("tau_syn_ex"): "1."} + _parameters = {sympy.Symbol("I_e", real=True): "1.", + sympy.Symbol("Tau", real=True): "1.", + sympy.Symbol("C_m", real=True): "1.", + sympy.Symbol("tau_syn_in", real=True): "1.", + sympy.Symbol("tau_syn_ex", real=True): "1."} def test_system_of_equations(self): - all_symbols = [sympy.Symbol(n) for n in ["I_in", "I_in__d", "I_ex", "I_ex__d", "V_m"]] + all_symbols = [sympy.Symbol(n, real=True) for n in ["I_in", "I_in__d", "I_ex", "I_ex__d", "V_m"]] shape_inh = Shape.from_function("I_in", "(e/tau_syn_in) * t * exp(-t/tau_syn_in)") shape_exc = Shape.from_function("I_ex", "(e/tau_syn_ex) * t * exp(-t/tau_syn_ex)") diff --git a/tests/test_lorenz_attractor.py b/tests/test_lorenz_attractor.py index fe7c2e60..b67e81bf 100644 --- a/tests/test_lorenz_attractor.py +++ b/tests/test_lorenz_attractor.py @@ -21,6 +21,7 @@ import sympy import sympy.parsing.sympy_parser +from odetoolbox.sympy_helpers import _sympy_parse_real from tests.test_utils import _open_json @@ -41,9 +42,9 @@ def test_lorenz_attractor(self): assert len(solver_dict) == 1 solver_dict = solver_dict[0] assert solver_dict["solver"].startswith("numeric") - assert sympy.parsing.sympy_parser.parse_expr(solver_dict["update_expressions"]["x"], global_dict=Shape._sympy_globals).expand().simplify() \ - == sympy.parsing.sympy_parser.parse_expr("sigma*(-x + y)", global_dict=Shape._sympy_globals).expand().simplify() - assert sympy.parsing.sympy_parser.parse_expr(solver_dict["update_expressions"]["y"], global_dict=Shape._sympy_globals).expand().simplify() \ - == sympy.parsing.sympy_parser.parse_expr("rho*x - x*z - y", global_dict=Shape._sympy_globals).expand().simplify() - assert sympy.parsing.sympy_parser.parse_expr(solver_dict["update_expressions"]["z"], global_dict=Shape._sympy_globals).expand().simplify() \ - == sympy.parsing.sympy_parser.parse_expr("-beta*z + x*y", global_dict=Shape._sympy_globals).expand().simplify() + assert _sympy_parse_real(solver_dict["update_expressions"]["x"], global_dict=Shape._sympy_globals).expand().simplify() \ + == _sympy_parse_real("sigma*(-x + y)", global_dict=Shape._sympy_globals).expand().simplify() + assert _sympy_parse_real(solver_dict["update_expressions"]["y"], global_dict=Shape._sympy_globals).expand().simplify() \ + == _sympy_parse_real("rho*x - x*z - y", global_dict=Shape._sympy_globals).expand().simplify() + assert _sympy_parse_real(solver_dict["update_expressions"]["z"], global_dict=Shape._sympy_globals).expand().simplify() \ + == _sympy_parse_real("-beta*z + x*y", global_dict=Shape._sympy_globals).expand().simplify() diff --git a/tests/test_mixed_integrator_numeric.py b/tests/test_mixed_integrator_numeric.py index 00e8e72a..7c999e91 100644 --- a/tests/test_mixed_integrator_numeric.py +++ b/tests/test_mixed_integrator_numeric.py @@ -74,7 +74,7 @@ def _run_simulation(indict, alias_spikes, integrator, params=None, **kwargs): T = 50E-3 # [s] initial_values = {"g_ex__d": 0., "g_in__d": 0.} # optionally override initial values - initial_values = {sympy.Symbol(k): v for k, v in initial_values.items()} + initial_values = {sympy.Symbol(k, real=True): v for k, v in initial_values.items()} spike_times = {"g_ex__d": np.array([10E-3]), "g_in__d": np.array([6E-3])} analysis_json, shape_sys, shapes = odetoolbox._analysis(indict, disable_stiffness_check=True, disable_analytic_solver=True, log_level="DEBUG", **kwargs) @@ -108,6 +108,7 @@ def _run_simulation(indict, alias_spikes, integrator, params=None, **kwargs): h_min_lower_bound=1E-12, raise_errors=True, debug=True) # debug needs to be True here to obtain the right return values + return h_min, h_avg, runtime, upper_bound_crossed, t_log, h_log, y_log, sym_list, analysis_json diff --git a/tests/test_propagator_solver_homogeneous.py b/tests/test_propagator_solver_homogeneous.py index b46b0b76..06ef482c 100644 --- a/tests/test_propagator_solver_homogeneous.py +++ b/tests/test_propagator_solver_homogeneous.py @@ -32,5 +32,8 @@ def test_propagator_solver_homogeneous(self): assert len(solver_dict) == 1 solver_dict = solver_dict[0] assert solver_dict["solver"] == "analytical" - assert float(solver_dict["propagators"]["__P__refr_t__refr_t"]) == 1. - assert solver_dict["propagators"]["__P__V_m__V_m"] == "1.0*exp(-__h/tau_m)" + + for cond_solver_dict in solver_dict["conditions"].values(): + assert float(cond_solver_dict["propagators"]["__P__refr_t__refr_t"]) == 1. + + assert solver_dict["conditions"]["default"]["propagators"]["__P__V_m__V_m"] == "1.0*exp(-__h/tau_m)" diff --git a/tests/test_singularity_detection.py b/tests/test_singularity_detection.py index 9af99215..4b3e6b5f 100644 --- a/tests/test_singularity_detection.py +++ b/tests/test_singularity_detection.py @@ -21,20 +21,33 @@ import io import logging +import numpy as np +import scipy import sympy import pytest +from odetoolbox.analytic_integrator import AnalyticIntegrator +from odetoolbox.spike_generator import SpikeGenerator + from .context import odetoolbox from tests.test_utils import _open_json from odetoolbox.singularity_detection import SingularityDetection -from odetoolbox.sympy_helpers import SymmetricEq +from odetoolbox.sympy_helpers import SymmetricEq, _sympy_parse_real + +try: + import matplotlib as mpl + mpl.use("Agg") + import matplotlib.pyplot as plt + INTEGRATION_TEST_DEBUG_PLOTS = True +except ImportError: + INTEGRATION_TEST_DEBUG_PLOTS = False class TestSingularityDetection: r"""Test singularity detection""" def test_is_matrix_defined_under_substitution(self): - tau_m, tau_r, C, h = sympy.symbols("tau_m, tau_r, C, h") + tau_m, tau_r, C, h = sympy.symbols("tau_m, tau_r, C, h", real=True) P = sympy.Matrix([[-1 / tau_r, 0, 0], [1, -1 / tau_r, 0], [0, 1 / C, -1 / tau_m]]) assert SingularityDetection._is_matrix_defined_under_substitution(P, set()) assert SingularityDetection._is_matrix_defined_under_substitution(P, set([SymmetricEq(tau_r, 1)])) @@ -44,10 +57,10 @@ def test_is_matrix_defined_under_substitution(self): def test_alpha_beta_kernels(self, kernel_to_use: str): r"""Test correctness of result for simple leaky integrate-and-fire neuron with biexponential postsynaptic kernel""" if kernel_to_use == "alpha": - tau_m, tau_s, C, h = sympy.symbols("tau_m, tau_s, C, h") + tau_m, tau_s, C, h = sympy.symbols("tau_m, tau_s, C, h", real=True) A = sympy.Matrix([[-1 / tau_s, 0, 0], [1, -1 / tau_s, 0], [0, 1 / C, -1 / tau_m]]) elif kernel_to_use == "beta": - tau_m, tau_d, tau_r, C, h = sympy.symbols("tau_m, tau_d, tau_r, C, h") + tau_m, tau_d, tau_r, C, h = sympy.symbols("tau_m, tau_d, tau_r, C, h", real=True) A = sympy.Matrix([[-1 / tau_d, 0, 0], [1, -1 / tau_r, 0], [0, 1 / C, -1 / tau_m]]) P = sympy.simplify(sympy.exp(A * h)) # Propagator matrix @@ -62,14 +75,186 @@ def test_alpha_beta_kernels(self, kernel_to_use: str): def test_more_than_one_solution(self): r"""Test the case where there is more than one element returned in a solution to an equation; in this example, for a quadratic input equation""" - A = sympy.Matrix([[sympy.parsing.sympy_parser.parse_expr("-1/(tau_s**2 - 3*tau_s - 42)")]]) + tau_s = sympy.Symbol("tau_s", real=True) + expr = _sympy_parse_real("-1/(tau_s**2 - 3*tau_s - 42)", local_dict={"tau_s": tau_s}) + A = sympy.Matrix([[expr]]) conditions = SingularityDetection._generate_singularity_conditions(A) assert len(conditions) == 2 - for cond_set in conditions: - for cond in cond_set: - assert sympy.Symbol("tau_s") == cond.lhs - assert cond.rhs == sympy.parsing.sympy_parser.parse_expr("3/2 + sqrt(177)/2") \ - or cond.rhs == sympy.parsing.sympy_parser.parse_expr("3/2 - sqrt(177)/2") + for cond in conditions: + assert sympy.Symbol("tau_s", real=True) == cond.lhs + assert cond.rhs == _sympy_parse_real("3/2 + sqrt(177)/2") \ + or cond.rhs == _sympy_parse_real("3/2 - sqrt(177)/2") + + +class TestSingularityInBothPropagatorAndInhomogeneous: + r""" + Test singularity mitigations when there is simultaneously a potential singularity in the propagator matrix as well as in the inhomogeneous terms. + """ + + @pytest.mark.parametrize("tau_1, tau_2", [(10., 2.), (10., 10.)]) + @pytest.mark.parametrize("late_ltd_check, late_ltp_check", [(3.14, 2.71), (0., 0.)]) + def test_singularity_in_both_propagator_and_inhomogeneous(self, tau_1, tau_2, late_ltd_check, late_ltp_check): + + def time_to_max(tau_1, tau_2): + r""" + Time of maximum. + """ + if tau_1 == tau_2: + return tau_1 + + return (np.log(tau_1) - np.log(tau_2)) / (1. / tau_2 - 1. / tau_1) + + def unit_amplitude(tau_1, tau_2): + r""" + Scaling factor ensuring that amplitude of solution is one. + """ + tmax = time_to_max(tau_1, tau_2) + + return 1. / (np.exp(-tmax / tau_1) - np.exp(-tmax / tau_2)) + + def double_exponential_ode_flow(y, t, tau_1, tau_2, alpha, dt): + r""" + Rhs of ODE system to be solved. + """ + dy1dt = -y[0] / tau_1 + dy2dt = y[0] - y[1] / tau_2 + + return np.array([dy1dt, dy2dt]) + + def inhomogeneous_ode_flow(z, t, late_ltp_check, late_ltd_check, tau_z): + """ + Defines the differential equation for z. + dz/dt = f(z, t) + """ + dzdt = (((1.0 - z) * late_ltp_check - (z + 0.5) * late_ltd_check)) / tau_z + return dzdt + + dt = .125 # time resolution (ms) + T = 48. # simulation time (ms) + + w = 3.14 # weight (amplitude; pA) + alpha = 1. + input_spike_times = np.array([10., 32.]) # array of input spike times (ms) + stimuli = [{"type": "list", + "list": " ".join([str(el) for el in input_spike_times]), + "variables": ["I_aux"]}] + + spike_times = SpikeGenerator.spike_times_from_json(stimuli, T) + + indict = {"dynamics": [{"expression": "I_aux' = -I_aux / tau_1", # double exponential + "initial_values": {"I_aux": "0."}}, + {"expression": "I' = I_aux - I / tau_2", # double exponential + "initial_values": {"I": "0"}}, + {"expression": "z' = (((1 - z) * late_ltp_check) - (z + 0.5) * late_ltd_check) / tau_z", + "initial_value": "1"}], # ODE with inhomogeneous term + "options": {"output_timestep_symbol": "__h"}, + "parameters": {"tau_1": str(tau_1), + "tau_2": str(tau_2), + "w": str(w), + "alpha": str(alpha)}} + + # + # integration using the ODE-toolbox analytic integrator + # + + timevec = np.arange(0., T, dt) + + solver_dict = odetoolbox.analysis(indict, log_level="DEBUG", disable_stiffness_check=True) + assert len(solver_dict) == 1 + solver_dict = solver_dict[0] + assert solver_dict["solver"] == "analytical" + + # solver_dict["parameters"] = {} + solver_dict["parameters"]["tau_z"] = 20. + solver_dict["parameters"]["late_ltd_check"] = late_ltd_check + solver_dict["parameters"]["late_ltp_check"] = late_ltp_check + + N = int(np.ceil(T / dt) + 1) + timevec = np.linspace(0., T, N) + analytic_integrator = AnalyticIntegrator(solver_dict, spike_times) + analytic_integrator.shape_starting_values[sympy.Symbol("I_aux", real=True)] = w * alpha + ODE_INITIAL_VALUES = {"I": 0., "I_aux": 0., "z": 0.} + analytic_integrator.set_initial_values(ODE_INITIAL_VALUES) + analytic_integrator.reset() + state = {"timevec": [], "I": [], "I_aux": [], "z": []} + for step, t in enumerate(timevec): + state_ = analytic_integrator.get_value(t) + state["timevec"].append(t) + for sym, val in state_.items(): + state[str(sym)].append(val) + + actual = [analytic_integrator.get_value(t)[sympy.Symbol("z", real=True)] for t in timevec] + + # + # integration using scipy.integrate.odeint + # + + z0 = 0.0 # set the initial condition + params = solver_dict["parameters"] + ode_args = ( + params["late_ltp_check"], + params["late_ltd_check"], + params["tau_z"] + ) + + solution = scipy.integrate.odeint(inhomogeneous_ode_flow, z0, timevec, args=ode_args, rtol=1E-12, atol=1E-12) + correct = solution.flatten().tolist() + + ts0 = np.arange(0., input_spike_times[0] + dt / 2, dt) + ts1 = np.arange(input_spike_times[0], input_spike_times[1] + dt / 2, dt) + ts2 = np.arange(input_spike_times[1], T + dt, dt) + + y_ = scipy.integrate.odeint(double_exponential_ode_flow, [0., 0.], ts0, args=(tau_1, tau_2, alpha, dt), rtol=1E-12, atol=1E-12) + y_ = np.vstack([y_[:-1, :], scipy.integrate.odeint(double_exponential_ode_flow, [y_[-1, 0] + w * alpha, y_[-1, 1]], ts1, args=(tau_1, tau_2, alpha, dt), rtol=1E-12, atol=1E-12)]) + y_ = np.vstack([y_[:-1, :], scipy.integrate.odeint(double_exponential_ode_flow, [y_[-1, 0] + w * alpha, y_[-1, 1]], ts2, args=(tau_1, tau_2, alpha, dt), rtol=1E-12, atol=1E-12)]) + + ts0 = ts0[:-1] + ts1 = ts1[:-1] + + if INTEGRATION_TEST_DEBUG_PLOTS: + + # + # plot the double exponential ODE + # + + tmax = time_to_max(tau_1, tau_2) + mpl.rcParams["text.usetex"] = True + + fig, ax = plt.subplots(nrows=6, figsize=(5, 4), dpi=600) + ax[0].plot(timevec, state["I"], "-", lw=3, color="k", label=r"$I(t)$ (ODEtb)") + ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 1], "-", lw=2, color="r", label=r"$I(t)$ (odeint)") + ax[1].plot(timevec, state["I_aux"], "--", lw=3, color="k", label=r"$I_\mathsf{aux}(t)$ (ODEtb)") + ax[1].plot(np.hstack([ts0, ts1, ts2]), y_[:, 0], "--", lw=2, color="r", label=r"$I_\mathsf{aux}(t)$ (odeint)") + + for tin in input_spike_times: + ax[0].vlines(tin + tmax, ax[0].get_ylim()[0], ax[0].get_ylim()[1], colors="k", linestyles=":") + ax[1].vlines(tin, ax[0].get_ylim()[0], ax[0].get_ylim()[1], colors="k", linestyles=":") + + ax[2].semilogy(np.hstack([ts0, ts1, ts2]), np.abs(y_[:, 1] - state["I"]), label="I") + ax[2].semilogy(np.hstack([ts0, ts1, ts2]), np.abs(y_[:, 0] - state["I_aux"]), linestyle="--", label="I_aux") + ax[2].set_ylabel("Error") + + ax[3].plot(timevec, correct, label="z (odeint)") + ax[4].plot(timevec, actual, label="z (ODEtb)") + ax[5].semilogy(timevec, np.abs(np.array(correct) - np.array(actual)), label="z") + ax[5].set_ylabel("Error") + ax[-1].set_xlabel("Time") + for _ax in ax: + _ax.set_xlim(0., T + dt) + _ax.legend() + _ax.grid() + if not _ax == ax[-1]: + _ax.set_xticklabels([]) + + fig.savefig("/tmp/test_singularity_simultaneous_[tau_1=" + str(tau_1) + "]_[tau_2=" + str(tau_2) + "]_[late_ltd_check=" + str(late_ltd_check) + "]_[late_ltp_check=" + str(late_ltp_check) + "].png") + + # + # test + # + + np.testing.assert_allclose(correct, actual) + np.testing.assert_allclose(y_[:, 1], state["I"], atol=1E-7) + np.testing.assert_allclose(y_[:, 0], state["I_aux"], atol=1E-7) class TestPropagatorSolverHomogeneous: diff --git a/tests/test_stiffness.py b/tests/test_stiffness.py index 672ab19a..5a41c43a 100644 --- a/tests/test_stiffness.py +++ b/tests/test_stiffness.py @@ -47,7 +47,6 @@ def test_canonical_stiff_system(self): assert len(result) == 1 \ and result[0]["solver"].endswith("explicit") - def test_morris_lecar_stiff(self): indict = _open_json("morris_lecar.json") diff --git a/tests/test_system_matrix_construction.py b/tests/test_system_matrix_construction.py index 9fcc6bcc..401b72df 100644 --- a/tests/test_system_matrix_construction.py +++ b/tests/test_system_matrix_construction.py @@ -22,6 +22,7 @@ import sympy from odetoolbox import _from_json_to_shapes +from odetoolbox.sympy_helpers import _sympy_parse_real from odetoolbox.system_of_shapes import SystemOfShapes from tests.test_utils import _open_json @@ -31,41 +32,39 @@ class TestSystemMatrixConstruction: def test_system_matrix_construction(self): indict = _open_json("system_matrix_test.json") shapes, parameters = _from_json_to_shapes(indict) - sigma, beta = sympy.symbols("sigma beta") + sigma, beta = sympy.symbols("sigma beta", real=True) shape_sys = SystemOfShapes.from_shapes(shapes, parameters=parameters) assert shape_sys.A_ == sympy.Matrix([[-sigma, sigma, 0], [0, 0, 0], [0, 0, -beta]]) - x, y, z = sympy.symbols("x y z") + x, y, z = sympy.symbols("x y z", real=True) assert shape_sys.c_ == sympy.Matrix([[0], [3 * z * x**2 - x * y], [x * y]]) - def test_lorenz_attractor(self): indict = _open_json("lorenz_attractor.json") shapes, parameters = _from_json_to_shapes(indict) - sigma, beta, rho = sympy.symbols("sigma beta rho") + sigma, beta, rho = sympy.symbols("sigma beta rho", real=True) shape_sys = SystemOfShapes.from_shapes(shapes, parameters=parameters) assert shape_sys.A_ == sympy.Matrix([[-sigma, sigma, 0], [rho, -1, 0], [0, 0, -beta]]) - x, y, z = sympy.symbols("x y z") + x, y, z = sympy.symbols("x y z", real=True) assert shape_sys.c_ == sympy.Matrix([[0], [-x * z], [x * y]]) - def test_morris_lecar(self): indict = _open_json("morris_lecar.json") shapes, parameters = _from_json_to_shapes(indict) shape_sys = SystemOfShapes.from_shapes(shapes, parameters=parameters) - C_m, g_Ca, g_K, g_L, E_Ca, E_K, E_L, I_ext = sympy.symbols("C_m g_Ca g_K g_L E_Ca E_K E_L I_ext") - assert shape_sys.A_ == sympy.Matrix([[sympy.parsing.sympy_parser.parse_expr("-500.0 * g_Ca / C_m - 1000.0 * g_L / C_m"), sympy.parsing.sympy_parser.parse_expr("1000.0 * E_K * g_K / C_m")], - [sympy.parsing.sympy_parser.parse_expr("0"), sympy.parsing.sympy_parser.parse_expr("0")]]) + C_m, g_Ca, g_K, g_L, E_Ca, E_K, E_L, I_ext = sympy.symbols("C_m g_Ca g_K g_L E_Ca E_K E_L I_ext", real=True) + assert shape_sys.A_ == sympy.Matrix([[_sympy_parse_real("-500.0 * g_Ca / C_m - 1000.0 * g_L / C_m"), _sympy_parse_real("1000.0 * E_K * g_K / C_m")], + [_sympy_parse_real("0"), _sympy_parse_real("0")]]) - V, W = sympy.symbols("V W") - assert shape_sys.b_ == sympy.Matrix([[sympy.parsing.sympy_parser.parse_expr("500.0 * E_Ca * g_Ca / C_m + 1000.0 * E_L * g_L / C_m + 1000.0 * I_ext / C_m")], - [sympy.parsing.sympy_parser.parse_expr("0")]]) - assert shape_sys.c_ == sympy.Matrix([[sympy.parsing.sympy_parser.parse_expr("500.0 * E_Ca * g_Ca * tanh(V / 15 + 1 / 15) / C_m - 1000.0 * V * W * g_K / C_m - 500.0 * V * g_Ca * tanh(V / 15 + 1 / 15) / C_m")], - [sympy.parsing.sympy_parser.parse_expr("-200.0 * W * cosh(V / 60) + 100.0 * cosh(V / 60) * tanh(V / 30) + 100.0 * cosh(V / 60)")]]) + V, W = sympy.symbols("V W", real=True) + assert shape_sys.b_ == sympy.Matrix([[_sympy_parse_real("500.0 * E_Ca * g_Ca / C_m + 1000.0 * E_L * g_L / C_m + 1000.0 * I_ext / C_m")], + [_sympy_parse_real("0")]]) + assert shape_sys.c_ == sympy.Matrix([[_sympy_parse_real("500.0 * E_Ca * g_Ca * tanh(V / 15 + 1 / 15) / C_m - 1000.0 * V * W * g_K / C_m - 500.0 * V * g_Ca * tanh(V / 15 + 1 / 15) / C_m")], + [_sympy_parse_real("-200.0 * W * cosh(V / 60) + 100.0 * cosh(V / 60) * tanh(V / 30) + 100.0 * cosh(V / 60)")]])