diff --git a/Examples/base/README.rst b/Examples/base/README.rst index 24202501..46da573a 100644 --- a/Examples/base/README.rst +++ b/Examples/base/README.rst @@ -3,4 +3,4 @@ Subclassing Examples ------------------------ -This section gathers examples which correspond to subclassing the :class:`easyscience.Objects.Base.BaseObj` class. +This section gathers examples which correspond to subclassing the :class:`easyscience.base_classes.ObjBase` class. diff --git a/Examples/base/plot_baseclass1.py b/Examples/base/plot_baseclass1.py index fa0539c8..b87c559e 100644 --- a/Examples/base/plot_baseclass1.py +++ b/Examples/base/plot_baseclass1.py @@ -1,8 +1,8 @@ """ -Subclassing BaseObj - Simple Pendulum +Subclassing ObjBase - Simple Pendulum ===================================== -This example shows how to subclass :class:`easyscience.Objects.Base.BaseObj` with parameters from -:class:`EasyScience.Objects.Base.Parameter`. For this example a simple pendulum will be modeled. +This example shows how to subclass :class:`easyscience.base_classes.ObjBase` with parameters from +:class:`EasyScience.variable.Parameter`. For this example a simple pendulum will be modeled. .. 
math:: y = A \sin (2 \pi f t + \phi ) @@ -17,8 +17,8 @@ import matplotlib.pyplot as plt import numpy as np -from easyscience.Objects.ObjectClasses import BaseObj -from easyscience.Objects.ObjectClasses import Parameter +from easyscience.base_classes import ObjBase +from easyscience.variable import Parameter # %% # Subclassing @@ -29,7 +29,7 @@ # embedded rST text block: -class Pendulum(BaseObj): +class Pendulum(ObjBase): def __init__(self, A: Parameter, f: Parameter, p: Parameter): super(Pendulum, self).__init__('SimplePendulum', A=A, f=f, p=p) @@ -41,13 +41,13 @@ def from_pars(cls, A: float = 1, f: float = 1, p: float = 0): return cls(A, f, p) def __call__(self, t): - return self.A.raw_value * np.sin(2 * np.pi * self.f.raw_value * t + self.p.raw_value) + return self.A.value * np.sin(2 * np.pi * self.f.value * t + self.p.value) def plot(self, time, axis=None, **kwargs): if axis is None: axis = plt else: - axis.set_title(f'A={self.A.raw_value}, F={self.f.raw_value}, P={self.p.raw_value}') + axis.set_title(f'A={self.A.value}, F={self.f.value}, P={self.p.value}') p = axis.plot(time, self(time), **kwargs) return p diff --git a/Examples/fitting/README.rst b/Examples/fitting/README.rst deleted file mode 100644 index e0c24c4d..00000000 --- a/Examples/fitting/README.rst +++ /dev/null @@ -1,6 +0,0 @@ -.. _fitting_examples: - -Fitting Examples ------------------------- - -This section gathers examples which correspond to fitting data. diff --git a/Examples/fitting/plot_constraints.py b/Examples/fitting/plot_constraints.py deleted file mode 100644 index b150bc82..00000000 --- a/Examples/fitting/plot_constraints.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Constraints example -=================== -This example shows the usages of the different constraints. 
-""" - -from easyscience import Constraints -from easyscience.Objects.ObjectClasses import Parameter - -p1 = Parameter('p1', 1) -constraint = Constraints.NumericConstraint(p1, '<', 5) -p1.user_constraints['c1'] = constraint - -for value in range(4, 7): - p1.value = value - print(f'Set Value: {value}, Parameter Value: {p1}') - -# %% -# To include embedded rST, use a line of >= 20 ``#``'s or ``#%%`` between your -# rST and your code. This separates your example -# into distinct text and code blocks. You can continue writing code below the -# embedded rST text block: diff --git a/LICENSE b/LICENSE index c1ee0cf3..f21bf746 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 3-Clause License -Copyright (c) 2024, Easyscience contributors (https://github.com/EasyScience) +Copyright (c) 2025, Easyscience contributors (https://github.com/EasyScience) All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/docs/src/conf.py b/docs/src/conf.py index 2d95445a..9bf23650 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -51,9 +51,7 @@ intersphinx_mapping = { 'python': ('https://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable/', None), - 'pint': ('https://pint.readthedocs.io/en/stable/', None), - 'xarray': ('https://xarray.pydata.org/en/stable/', None) + 'numpy': ('https://numpy.org/doc/stable/', None) } # -- General configuration --------------------------------------------------- diff --git a/docs/src/fitting/constraints.rst b/docs/src/fitting/constraints.rst deleted file mode 100644 index d92c87c2..00000000 --- a/docs/src/fitting/constraints.rst +++ /dev/null @@ -1,75 +0,0 @@ -====================== -Constraints -====================== - -Constraints are a fundamental component in non-trivial fitting operations. They can also be used to affirm the minimum/maximum of a parameter or tie parameters together in a model. 
- -Anatomy of a constraint ------------------------ - -A constraint is a rule which is applied to a **dependent** variable. This rule can consist of a logical operation, relation to one or more **independent** variables or an arbitrary function. - - -Constraints on Parameters -^^^^^^^^^^^^^^^^^^^^^^^^^ - -:class:`easyscience.Objects.Base.Parameter` has the properties `builtin_constraints` and `user_constraints`. These are dictionaries which correspond to constraints which are intrinsic and extrinsic to the Parameter. This means that on the value change of the Parameter firstly the `builtin_constraints` are evaluated, followed by the `user_constraints`. - - -Constraints on Fitting -^^^^^^^^^^^^^^^^^^^^^^ - -:class:`easyscience.fitting.Fitter` has the ability to evaluate user supplied constraints which effect the value of both fixed and non-fixed parameters. A good example of one such use case would be the ratio between two parameters, where you would create a :class:`easyscience.fitting.Constraints.ObjConstraint`. - -Using constraints ------------------ - -A constraint can be used in one of three ways; Assignment to a parameter, assignment to fitting or on demand. The first two are covered and on demand is shown below. - -.. code-block:: python - - from easyscience.fitting.Constraints import NumericConstraint - from easyscience.Objects.Base import Parameter - # Create an `a < 1` constraint - a = Parameter('a', 0.5) - constraint = NumericConstraint(a, '<=', 1) - # Evaluate the constraint on demand - a.value = 5.0 - constraint() - # A will now equal 1 - -Constraint Reference --------------------- - -.. minigallery:: easyscience.fitting.Constraints.NumericConstraint - :add-heading: Examples using `Constraints` - -Built-in constraints -^^^^^^^^^^^^^^^^^^^^ - -These are the built in constraints which you can use - -.. autoclass:: easyscience.fitting.Constraints.SelfConstraint - :members: +enabled - -.. 
autoclass:: easyscience.fitting.Constraints.NumericConstraint - :members: +enabled - -.. autoclass:: easyscience.fitting.Constraints.ObjConstraint - :members: +enabled - -.. autoclass:: easyscience.fitting.Constraints.FunctionalConstraint - :members: +enabled - -.. autoclass:: easyscience.fitting.Constraints.MultiObjConstraint - :members: +enabled - -User created constraints -^^^^^^^^^^^^^^^^^^^^^^^^ - -You can also make your own constraints by subclassing the :class:`easyscience.fitting.Constraints.ConstraintBase` class. For this at a minimum the abstract methods ``_parse_operator`` and ``__repr__`` need to be written. - -.. autoclass:: easyscience.fitting.Constraints.ConstraintBase - :members: - :private-members: - :special-members: __repr__ \ No newline at end of file diff --git a/docs/src/index.rst b/docs/src/index.rst index 3683a186..ca99dda5 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -56,7 +56,6 @@ Documentation :maxdepth: 3 fitting/introduction - fitting/constraints .. toctree:: :maxdepth: 2 diff --git a/docs/src/reference/base.rst b/docs/src/reference/base.rst index 59e9de32..ed3d05de 100644 --- a/docs/src/reference/base.rst +++ b/docs/src/reference/base.rst @@ -5,13 +5,13 @@ Parameters and Objects Descriptors =========== -.. autoclass:: easyscience.Objects.Variable.Descriptor +.. autoclass:: easyscience.variable.Descriptor :members: Parameters ========== -.. autoclass:: easyscience.Objects.Variable.Parameter +.. autoclass:: easyscience.variable.Parameter :members: :inherited-members: @@ -22,30 +22,17 @@ Super Classes and Collections Super Classes ============= -.. autoclass:: easyscience.Objects.ObjectClasses.BasedBase +.. autoclass:: easyscience.base_classes.BasedBase :members: :inherited-members: -.. autoclass:: easyscience.Objects.ObjectClasses.BaseObj +.. autoclass:: easyscience.base_classes.ObjBase :members: +_add_component :inherited-members: Collections =========== -.. autoclass:: easyscience.Objects.Groups.BaseCollection +.. 
autoclass:: easyscience.CollectionBase :members: :inherited-members: - -=============== -Data Containers -=============== - -.. autoclass:: easyscience.Datasets.xarray.EasyScienceDataarrayAccessor - :members: - :inherited-members: - -.. autoclass:: easyscience.Datasets.xarray.EasyScienceDatasetAccessor - :members: - :inherited-members: - diff --git a/examples_old/example1.py b/examples_old/example1.py index 342f9931..14609f7a 100644 --- a/examples_old/example1.py +++ b/examples_old/example1.py @@ -16,7 +16,7 @@ def fit_fun(x): # In the real case we would gust call the evaluation fn without reference to the BaseObj - return b.c.raw_value + b.m.raw_value * x + return b.c.value + b.m.value * x f = Fitter() diff --git a/examples_old/example1_dream.py b/examples_old/example1_dream.py index 74e90835..0b5621be 100644 --- a/examples_old/example1_dream.py +++ b/examples_old/example1_dream.py @@ -14,7 +14,7 @@ def fit_fun(x): # In the real case we would gust call the evaluation fn without reference to the BaseObj - return b.c.raw_value + b.m.raw_value * x + return b.c.value + b.m.value * x f = Fitter() diff --git a/examples_old/example2.py b/examples_old/example2.py index 3e0f4a7f..6e1acbd2 100644 --- a/examples_old/example2.py +++ b/examples_old/example2.py @@ -27,11 +27,11 @@ def _defaults(self): @property def gradient(self): - return self.m.raw_value + return self.m.value @property def intercept(self): - return self.c.raw_value + return self.c.value def fit_func(self, x: np.ndarray) -> np.ndarray: return self.gradient * x + self.intercept diff --git a/examples_old/example3.py b/examples_old/example3.py index d01a861e..49d6ede9 100644 --- a/examples_old/example3.py +++ b/examples_old/example3.py @@ -57,14 +57,14 @@ def gradient(self): if self.interface: return self.interface.get_value('m') else: - return self.m.raw_value + return self.m.value @property def intercept(self): if self.interface: return self.interface.get_value('c') else: - return self.c.raw_value + return 
self.c.value def fit_func(self, x: np.ndarray) -> np.ndarray: if self.interface: diff --git a/examples_old/example4.py b/examples_old/example4.py index 376d4486..720ff961 100644 --- a/examples_old/example4.py +++ b/examples_old/example4.py @@ -11,7 +11,7 @@ from easyscience import global_object from easyscience.fitting import Fitter -from easyscience.Objects.core import ComponentSerializer +from easyscience.Objects.component_serializer import ComponentSerializer from easyscience.Objects.ObjectClasses import BaseObj from easyscience.Objects.ObjectClasses import Parameter @@ -407,14 +407,14 @@ def gradient(self): if self.interface: return self.interface().get_value("m") else: - return self.m.raw_value + return self.m.value @property def intercept(self): if self.interface: return self.interface().get_value("c") else: - return self.c.raw_value + return self.c.value def __repr__(self): return f"Line: m={self.m}, c={self.c}" diff --git a/examples_old/example5_broken.py b/examples_old/example5_broken.py index dd5fa5e0..a7967ee8 100644 --- a/examples_old/example5_broken.py +++ b/examples_old/example5_broken.py @@ -12,7 +12,7 @@ from easyscience.fitting import Fitter from easyscience.Objects.Base import BaseObj from easyscience.Objects.Base import Parameter -from easyscience.Objects.core import ComponentSerializer +from easyscience.Objects.component_serializer import ComponentSerializer # from easyscience.Objects.Base import LoggedProperty from easyscience.Objects.Inferface import InterfaceFactoryTemplate @@ -325,14 +325,14 @@ def gradient(self): # if self.interface: # return self.interface().get_value('m') # else: - return self.m.raw_value + return self.m.value @property def intercept(self): # if self.interface: # return self.interface().get_value('c') # else: - return self.c.raw_value + return self.c.value def __repr__(self): return f"Line: m={self.m}, c={self.c}" diff --git a/examples_old/example6_broken.py b/examples_old/example6_broken.py index c93f7577..fe2dbbf2 
100644 --- a/examples_old/example6_broken.py +++ b/examples_old/example6_broken.py @@ -12,7 +12,7 @@ from easyscience.fitting import Fitter from easyscience.Objects.ObjectClasses import BaseObj from easyscience.Objects.Variable import Parameter -from easyscience.Objects.core import ComponentSerializer +from easyscience.Objects.component_serializer import ComponentSerializer from easyscience.Objects.Inferface import InterfaceFactoryTemplate # This is a much more complex case where we have calculators, interfaces, interface factory and an diff --git a/examples_old/example_dataset2.py b/examples_old/example_dataset2.py index f2d415f9..b235bc4f 100644 --- a/examples_old/example_dataset2.py +++ b/examples_old/example_dataset2.py @@ -16,7 +16,7 @@ def fit_fun(x, *args, **kwargs): # In the real case we would gust call the evaluation fn without reference to the BaseObj - return b.c.raw_value + b.m.raw_value * x + return b.c.value + b.m.value * x f = Fitter() diff --git a/examples_old/example_dataset2pt2_broken.py b/examples_old/example_dataset2pt2_broken.py index 9c44f0a5..ef4dcbab 100644 --- a/examples_old/example_dataset2pt2_broken.py +++ b/examples_old/example_dataset2pt2_broken.py @@ -22,7 +22,7 @@ def fit_fun(x, *args, **kwargs): # In the real case we would gust call the evaluation fn without reference to the BaseObj - return b.c.raw_value + b.m.raw_value * x + return b.c.value + b.m.value * x nx = 1E3 diff --git a/examples_old/example_dataset3.py b/examples_old/example_dataset3.py index fb8bf7f0..4164dc5e 100644 --- a/examples_old/example_dataset3.py +++ b/examples_old/example_dataset3.py @@ -36,7 +36,7 @@ def fit_fun(x, *args, **kwargs): # In the real case we would gust call the evaluation fn without reference to the BaseObj - return np.sin(2*np.pi*(x[:, 0] + b.s_off.raw_value)) * np.cos(2*np.pi*(x[:, 1] + b.c_off.raw_value)) + return np.sin(2*np.pi*(x[:, 0] + b.s_off.value)) * np.cos(2*np.pi*(x[:, 1] + b.c_off.value)) f = Fitter() diff --git 
a/examples_old/example_dataset3pt2.py b/examples_old/example_dataset3pt2.py index ac2c9666..187306dd 100644 --- a/examples_old/example_dataset3pt2.py +++ b/examples_old/example_dataset3pt2.py @@ -34,7 +34,7 @@ def fit_fun(x, *args, **kwargs): # In the real case we would gust call the evaluation fn without reference to the BaseObj - return np.sin(2*np.pi*(x[:, 0] + b.s_off.raw_value)) * np.cos(2*np.pi*(x[:, 1] + b.c_off.raw_value)) + return np.sin(2*np.pi*(x[:, 0] + b.s_off.value)) * np.cos(2*np.pi*(x[:, 1] + b.c_off.value)) fig, ax = plt.subplots(2, 3, sharey=True, sharex=True) @@ -55,7 +55,7 @@ def fit_fun(x, *args, **kwargs): p1 = d[f'computed_{minimizer}'].plot(ax=ax[0, idx], cbar_kwargs={'cax': cbar_ax1}) p2 = d[f'dz_{minimizer}'].plot(ax=ax[1, idx], cbar_kwargs={'cax': cbar_ax2}) ax[0, idx].set_title(f'{minimizer}') - ax[1, idx].set_title('s_off - {:0.03f}\nc_off - {:0.03f}'.format(b.s_off.raw_value, b.c_off.raw_value)) + ax[1, idx].set_title('s_off - {:0.03f}\nc_off - {:0.03f}'.format(b.s_off.value, b.c_off.value)) ax[0, idx].set_aspect('equal', 'box') ax[1, idx].set_aspect('equal', 'box') fig.subplots_adjust(right=0.8) diff --git a/examples_old/example_dataset4.py b/examples_old/example_dataset4.py index 9e4c41cb..1e975d4e 100644 --- a/examples_old/example_dataset4.py +++ b/examples_old/example_dataset4.py @@ -18,7 +18,7 @@ def fit_fun(x, *args, **kwargs): # In the real case we would gust call the evaluation fn without reference to the BaseObj - return b.c.raw_value + b.m.raw_value * x + return b.c.value + b.m.value * x f = Fitter() diff --git a/examples_old/example_dataset4_2.py b/examples_old/example_dataset4_2.py index 375e5a11..9bba2ffd 100644 --- a/examples_old/example_dataset4_2.py +++ b/examples_old/example_dataset4_2.py @@ -33,7 +33,7 @@ def from_params(cls, amplitude: float = 1, phase: float = 0, period: float = 2*n def fit_fun(self, x, *args, **kwargs): # In the real case we would gust call the evaluation fn without reference to the BaseObj - 
return self.amplitude.raw_value * np.sin((x + self.phase.raw_value)/self.period.raw_value) + return self.amplitude.value * np.sin((x + self.phase.value)/self.period.value) b = Wavey.from_params() bb = Wavey.from_params(1.1, 0.1, 1.9*np.pi) diff --git a/pyproject.toml b/pyproject.toml index 734444b8..cfa5f735 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,8 +35,6 @@ dependencies = [ "lmfit", "numpy", "uncertainties", - "xarray", - "pint", # Only to ensure that unit is reported as dimensionless rather than empty string "scipp" ] @@ -75,6 +73,7 @@ packages = ["src"] [tool.hatch.build.targets.wheel] packages = ["src/easyscience"] +exclude = ["src/easyscience/legacy"] [tool.coverage.run] source = ["src/easyscience"] diff --git a/resources/scripts/generate_html.py b/resources/scripts/generate_html.py index 92a90981..0a399639 100644 --- a/resources/scripts/generate_html.py +++ b/resources/scripts/generate_html.py @@ -1,5 +1,3 @@ -__author__ = 'github.com/wardsimon' -__version__ = '0.0.1' import sys diff --git a/src/easyscience/Constraints.py b/src/easyscience/Constraints.py deleted file mode 100644 index 9628db9e..00000000 --- a/src/easyscience/Constraints.py +++ /dev/null @@ -1,498 +0,0 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project bool: - """ - Is the current constraint enabled. - - :return: Logical answer to if the constraint is enabled. - """ - return self._enabled - - @enabled.setter - def enabled(self, enabled_value: bool): - """ - Set the enabled state of the constraint. If the new value is the same as the current value only the state is - changed. - - ... note:: If the new value is ``True`` the constraint is also applied after enabling. - - :param enabled_value: New state of the constraint. 
- :return: None - """ - - if self._enabled == enabled_value: - return - elif enabled_value: - self.get_obj(self.dependent_obj_ids).enabled = False - self() - else: - self.get_obj(self.dependent_obj_ids).enabled = True - self._enabled = enabled_value - - def __call__(self, *args, no_set: bool = False, **kwargs): - """ - Method which applies the constraint - - :return: None if `no_set` is False, float otherwise. - """ - if not self.enabled: - if no_set: - return None - return - independent_objs = None - if isinstance(self.dependent_obj_ids, str): - dependent_obj = self.get_obj(self.dependent_obj_ids) - else: - raise AttributeError - if isinstance(self.independent_obj_ids, str): - independent_objs = self.get_obj(self.independent_obj_ids) - elif isinstance(self.independent_obj_ids, list): - independent_objs = [self.get_obj(obj_id) for obj_id in self.independent_obj_ids] - if independent_objs is not None: - value = self._parse_operator(independent_objs, *args, **kwargs) - else: - value = self._parse_operator(dependent_obj, *args, **kwargs) - - if not no_set: - toggle = False - if not dependent_obj.enabled: - dependent_obj.enabled = True - toggle = True - dependent_obj.value = value - if toggle: - dependent_obj.enabled = False - return value - - @abstractmethod - def _parse_operator(self, obj: V, *args, **kwargs) -> Number: - """ - Abstract method which contains the constraint logic - - :param obj: The object/objects which the constraint will use - :return: A numeric result of the constraint logic - """ - - @abstractmethod - def __repr__(self): - pass - - def get_obj(self, key: int) -> V: - """ - Get an EasyScience object from its unique key - - :param key: an EasyScience objects unique key - :return: EasyScience object - """ - return self._global_object.map.get_item_by_key(key) - - -C = TypeVar('C', bound=ConstraintBase) - - -class NumericConstraint(ConstraintBase): - """ - A `NumericConstraint` is a constraint whereby a dependent parameters value is something of an 
independent parameters - value. I.e. a < 1, a > 5 - """ - - def __init__(self, dependent_obj: V, operator: str, value: Number): - """ - A `NumericConstraint` is a constraint whereby a dependent parameters value is something of an independent - parameters value. I.e. a < 1, a > 5 - - :param dependent_obj: Dependent Parameter - :param operator: Relation to between the parameter and the values. e.g. ``=``, ``<``, ``>`` - :param value: What the parameters value should be compared against. - - :example: - - .. code-block:: python - - from easyscience.fitting.Constraints import NumericConstraint - from easyscience.Objects.Base import Parameter - # Create an `a < 1` constraint - a = Parameter('a', 0.2) - constraint = NumericConstraint(a, '<=', 1) - a.user_constraints['LEQ_1'] = constraint - # This works - a.value = 0.85 - # This triggers the constraint - a.value = 2.0 - # `a` is set to the maximum of the constraint (`a = 1`) - """ - super(NumericConstraint, self).__init__(dependent_obj, operator=operator, value=value) - - def _parse_operator(self, obj: V, *args, **kwargs) -> Number: - ## TODO Probably needs to be updated when DescriptorArray is implemented - - value = obj.value_no_call_back - - if isinstance(value, list): - value = np.array(value) - self.aeval.symtable['value1'] = value - self.aeval.symtable['value2'] = self.value - try: - self.aeval.eval(f'value3 = value1 {self.operator} value2') - logic = self.aeval.symtable['value3'] - if isinstance(logic, np.ndarray): - value[not logic] = self.aeval.symtable['value2'] - else: - if not logic: - value = self.aeval.symtable['value2'] - except Exception as e: - raise e - finally: - self.aeval = Interpreter() - return value - - def __repr__(self) -> str: - return f'{self.__class__.__name__} with `value` {self.operator} {self.value}' - - -class SelfConstraint(ConstraintBase): - """ - A `SelfConstraint` is a constraint which tests a logical constraint on a property of itself, similar to a - `NumericConstraint`. i.e. 
a > a.min. These constraints are usually used in the internal EasyScience logic. - """ - - def __init__(self, dependent_obj: V, operator: str, value: str): - """ - A `SelfConstraint` is a constraint which tests a logical constraint on a property of itself, similar to - a `NumericConstraint`. i.e. a > a.min. - - :param dependent_obj: Dependent Parameter - :param operator: Relation to between the parameter and the values. e.g. ``=``, ``<``, ``>`` - :param value: Name of attribute to be compared against - - :example: - - .. code-block:: python - - from easyscience.fitting.Constraints import SelfConstraint - from easyscience.Objects.Base import Parameter - # Create an `a < a.max` constraint - a = Parameter('a', 0.2, max=1) - constraint = SelfConstraint(a, '<=', 'max') - a.user_constraints['MAX'] = constraint - # This works - a.value = 0.85 - # This triggers the constraint - a.value = 2.0 - # `a` is set to the maximum of the constraint (`a = 1`) - """ - super(SelfConstraint, self).__init__(dependent_obj, operator=operator, value=value) - - def _parse_operator(self, obj: V, *args, **kwargs) -> Number: - value = obj.value_no_call_back - - self.aeval.symtable['value1'] = value - self.aeval.symtable['value2'] = getattr(obj, self.value) - try: - self.aeval.eval(f'value3 = value1 {self.operator} value2') - logic = self.aeval.symtable['value3'] - if isinstance(logic, np.ndarray): - value[not logic] = self.aeval.symtable['value2'] - else: - if not logic: - value = self.aeval.symtable['value2'] - except Exception as e: - raise e - finally: - self.aeval = Interpreter() - return value - - def __repr__(self) -> str: - return f'{self.__class__.__name__} with `value` {self.operator} obj.{self.value}' - - -class ObjConstraint(ConstraintBase): - """ - A `ObjConstraint` is a constraint whereby a dependent parameter is something of an independent parameter - value. E.g. 
a (Dependent Parameter) = 2* b (Independent Parameter) - """ - - def __init__(self, dependent_obj: V, operator: str, independent_obj: V): - """ - A `ObjConstraint` is a constraint whereby a dependent parameter is something of an independent parameter - value. E.g. a (Dependent Parameter) < b (Independent Parameter) - - :param dependent_obj: Dependent Parameter - :param operator: Relation to between the independent parameter and dependent parameter. e.g. ``2 *``, ``1 +`` - :param independent_obj: Independent Parameter - - :example: - - .. code-block:: python - - from easyscience.fitting.Constraints import ObjConstraint - from easyscience.Objects.Base import Parameter - # Create an `a = 2 * b` constraint - a = Parameter('a', 0.2) - b = Parameter('b', 1) - - constraint = ObjConstraint(a, '2*', b) - b.user_constraints['SET_A'] = constraint - b.value = 1 - # This triggers the constraint - a.value # Should equal 2 - - """ - super(ObjConstraint, self).__init__(dependent_obj, independent_obj=independent_obj, operator=operator) - self.external = True - - def _parse_operator(self, obj: V, *args, **kwargs) -> Number: - value = obj.value_no_call_back - - self.aeval.symtable['value1'] = value - try: - self.aeval.eval(f'value2 = {self.operator} value1') - value = self.aeval.symtable['value2'] - except Exception as e: - raise e - finally: - self.aeval = Interpreter() - return value - - def __repr__(self) -> str: - return f'{self.__class__.__name__} with `dependent_obj` = {self.operator} `independent_obj`' - - -class MultiObjConstraint(ConstraintBase): - """ - A `MultiObjConstraint` is similar to :class:`EasyScience.fitting.Constraints.ObjConstraint` except that it relates to - multiple independent objects. 
- """ - - def __init__( - self, - independent_objs: List[V], - operator: List[str], - dependent_obj: V, - value: Number, - ): - """ - A `MultiObjConstraint` is similar to :class:`EasyScience.fitting.Constraints.ObjConstraint` except that it relates - to one or more independent objects. - - E.g. - * a (Dependent Parameter) + b (Independent Parameter) = 1 - * a (Dependent Parameter) + b (Independent Parameter) - 2*c (Independent Parameter) = 0 - - :param independent_objs: List of Independent Parameters - :param operator: List of operators operating on the Independent Parameters - :param dependent_obj: Dependent Parameter - :param value: Value of the expression - - :example: - - **a + b = 1** - - .. code-block:: python - - from easyscience.fitting.Constraints import MultiObjConstraint - from easyscience.Objects.Base import Parameter - # Create an `a + b = 1` constraint - a = Parameter('a', 0.2) - b = Parameter('b', 0.3) - - constraint = MultiObjConstraint([b], ['+'], a, 1) - b.user_constraints['SET_A'] = constraint - b.value = 0.4 - # This triggers the constraint - a.value # Should equal 0.6 - - **a + b - 2c = 0** - - .. code-block:: python - - from easyscience.fitting.Constraints import MultiObjConstraint - from easyscience.Objects.Base import Parameter - # Create an `a + b - 2c = 0` constraint - a = Parameter('a', 0.5) - b = Parameter('b', 0.3) - c = Parameter('c', 0.1) - - constraint = MultiObjConstraint([b, c], ['+', '-2*'], a, 0) - b.user_constraints['SET_A'] = constraint - c.user_constraints['SET_A'] = constraint - b.value = 0.4 - # This triggers the constraint. Or it could be triggered by changing the value of c - a.value # Should equal 0.2 - - .. 
note:: This constraint is evaluated as ``dependent`` = ``value`` - SUM(``operator_i`` ``independent_i``) - """ - super(MultiObjConstraint, self).__init__( - dependent_obj, - independent_obj=independent_objs, - operator=operator, - value=value, - ) - self.external = True - - def _parse_operator(self, independent_objs: List[V], *args, **kwargs) -> Number: - - in_str = '' - value = None - for idx, obj in enumerate(independent_objs): - self.aeval.symtable['p' + str(self.independent_obj_ids[idx])] = obj.value_no_call_back - - in_str += ' p' + str(self.independent_obj_ids[idx]) - if idx < len(self.operator): - in_str += ' ' + self.operator[idx] - try: - self.aeval.eval(f'final_value = {self.value} - ({in_str})') - value = self.aeval.symtable['final_value'] - except Exception as e: - raise e - finally: - self.aeval = Interpreter() - return value - - def __repr__(self) -> str: - return f'{self.__class__.__name__}' - - -class FunctionalConstraint(ConstraintBase): - """ - Functional constraints do not depend on other parameters and as such can be more complex. - """ - - def __init__( - self, - dependent_obj: V, - func: Callable, - independent_objs: Optional[List[V]] = None, - ): - """ - Functional constraints do not depend on other parameters and as such can be more complex. - - :param dependent_obj: Dependent Parameter - :param func: Function to be evaluated in the form ``f(value, *args, **kwargs)`` - - :example: - - .. 
code-block:: python - - import numpy as np - from easyscience.fitting.Constraints import FunctionalConstraint - from easyscience.Objects.Base import Parameter - - a = Parameter('a', 0.2, max=1) - constraint = FunctionalConstraint(a, np.abs) - - a.user_constraints['abs'] = constraint - - # This triggers the constraint - a.value = 0.85 # `a` is set to 0.85 - # This triggers the constraint - a.value = -0.5 # `a` is set to 0.5 - """ - super(FunctionalConstraint, self).__init__(dependent_obj, independent_obj=independent_objs) - self.function = func - if independent_objs is not None: - self.external = True - - def _parse_operator(self, obj: V, *args, **kwargs) -> Number: - - self.aeval.symtable[f'f{id(self.function)}'] = self.function - value_str = f'r_value = f{id(self.function)}(' - if isinstance(obj, list): - for o in obj: - value_str += f'{o.value_no_call_back},' - - value_str = value_str[:-1] - else: - value_str += f'{obj.value_no_call_back}' - - value_str += ')' - try: - self.aeval.eval(value_str) - value = self.aeval.symtable['r_value'] - except Exception as e: - raise e - finally: - self.aeval = Interpreter() - return value - - def __repr__(self) -> str: - return f'{self.__class__.__name__}' - - -def cleanup_constraint(obj_id: str, enabled: bool): - try: - obj = global_object.map.get_item_by_key(obj_id) - obj.enabled = enabled - except ValueError: - if global_object.debug: - print(f'Object with ID {obj_id} has already been deleted') diff --git a/src/easyscience/Datasets/__init__.py b/src/easyscience/Datasets/__init__.py deleted file mode 100644 index 22e236a6..00000000 --- a/src/easyscience/Datasets/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project str: - """ - Get the common name of the DataSet. 
- - :return: Common name of the DataSet - :rtype: str - """ - return self._obj.attrs['name'] - - @name.setter - def name(self, new_name: str): - """ - Set the common name of the DataSet i.e could be experiment name... - - :param new_name: Common name of the DataSet - :type new_name: str - :return: None - :rtype: None - """ - self._obj.attrs['name'] = new_name - - @property - def description(self) -> str: - """ - Get a description of the DataSet - - :return: Description of the DataSet - :rtype: str - """ - return self._obj.attrs['description'] - - @description.setter - def description(self, new_description: str): - """ - Set the description of the DataSet - - :param new_description: Description of the DataSet - :type new_description: str - :return: None - :rtype: None - """ - self._obj.attrs['description'] = new_description - - @property - def url(self) -> str: - """ - Get the url of the DataSet - - :return: URL of the DataSet (empty if no URL) - :rtype: str - """ - return self._obj.attrs['url'] - - @url.setter - def url(self, new_url: str): - """ - Set the URL of the DataSet. This may be a DOI. - - :param new_url: New URL/DOI of the DataSet - :type new_url: str - :return:None - :rtype: None - """ - self._obj.attrs['url'] = new_url - - @property - def core_object(self): - """ - Get the core object associated to a DataSet. Note that this is called from a weakref. If the EasyScience obj is - garbage collected, None will be returned. - - :return: EasyScience object associated with the DataSet - :rtype: Any - """ - if self._core_object is None: - return None - return self._core_object() - - @core_object.setter - def core_object(self, new_core_object: Any): - """ - Associate an EasyScience object to a DataSet. 
- - :param new_core_object: EasyScience object to be associated to the DataSet - :type new_core_object: Any - :return: None - :rtype: None - """ - self._core_object = weakref.ref(new_core_object) - - def add_coordinate( - self, - coordinate_name: str, - coordinate_values: Union[List[T_], np.ndarray], - unit: str = '', - ): - """ - Add a coordinate to the DataSet. This can be then be assigned to one or more DataArrays. - - :param coordinate_name: Name of the coordinate e.g. `x` - :type coordinate_name: str - :param coordinate_values: Points for the coordinates - :type coordinate_values: Union[List[T_], numpy.ndarray] - :param unit: Unit associated with the coordinate - :type unit: str - :return: None - :rtype: None - """ - self._obj.coords[coordinate_name] = coordinate_values - self._obj.attrs['units'][coordinate_name] = ureg.Unit(unit) - - def remove_coordinate(self, coordinate_name: str): - """ - Remove a coordinate from the DataSet. Note that this will not remove the coordinate from DataArrays which have - already used the it! - - :param coordinate_name: Name of the coordinate to be removed - :type coordinate_name: str - :return: None - :rtype: None - """ - del self._obj.coords[coordinate_name] - del self._obj.attrs['units'][coordinate_name] - - def add_variable( - self, - variable_name, - variable_coordinates: Union[str, List[str]], - variable_values: Union[List[T_], np.ndarray], - variable_sigma: Union[List[T_], np.ndarray] = None, - unit: str = '', - auto_sigma: bool = False, - ): - """ - Create a DataArray from known coordinates and data, assign it to the dataset under a given name. Variances can - be calculated assuming gaussian distribution to 1 sigma. - - :param variable_name: Name of the DataArray which will be created and added to the dataset - :type variable_name: str - :param variable_coordinates: List of coordinates used in the supplied data array. 
- :type variable_coordinates: str, List[str] - :param variable_values: Numpy or list of data which will be assigned to the DataArray - :type variable_values: Union[numpy.ndarray, list] - :param variable_sigma: If the sigmas of the dataset are known, they can be supplied here. - :type variable_sigma: Union[numpy.ndarray, list] - :param unit: Unit associated with the DataArray - :type unit: str - :param auto_sigma: Should the sigma DataArray be automatically calculated assuming gaussian probability? - :type auto_sigma: bool - :return: None - :rtype: None - """ - - # Check if a user has supplied a coordinate as a string. Make it a list of strings - if isinstance(variable_coordinates, str): - variable_coordinates = [variable_coordinates] - - # The variable_coordinates can be any iterable object. Though we would assume list/tuple - if not isinstance(variable_coordinates, Iterable): - raise ValueError('The variable coordinates must be a list of strings') - - # Check to see if the user want to assign a coordinate which does not exist yet. - known_keys = self._obj.coords.keys() - for dimension in variable_coordinates: - if dimension not in known_keys: - raise ValueError(f'The supplied coordinate `{dimension}` must first be defined.') - - # Create the dataset. - self._obj[variable_name] = (variable_coordinates, variable_values) - - # Deal with sigmas - if variable_sigma is not None: - # CASE 1, user has supplied sigmas - if isinstance(variable_sigma, Callable): - # CASE 1-1, The sigmas are created by some kind of generator - self.sigma_generator(variable_name, variable_sigma) - elif isinstance(variable_sigma, np.ndarray): - # CASE 1-2, The sigmas are a numpy arrays - self.sigma_attach(variable_name, variable_sigma) - elif isinstance(variable_sigma, list): - # CASE 1-3, We have been given a list. 
Make it a numpy array - self.sigma_attach(variable_name, np.array(variable_sigma)) - else: - raise ValueError('User supplied sigmas must be of the form; Callable fn, numpy array, list') - else: - # CASE 2, No sigmas have been supplied. - if auto_sigma: - # CASE 2-1, Automatically generate the sigmas using gaussian probability - self.sigma_generator(variable_name) - - # Set units for the newly created DataArray - self._obj.attrs['units'][variable_name] = ureg.Unit(unit) - # If a sigma has been attached, attempt to work out the units. - if unit and variable_sigma is None and auto_sigma: - self._obj.attrs['units'][self.sigma_label_prefix + variable_name] = ureg.Unit(unit + ' ** 0.5') - else: - if auto_sigma: - self._obj.attrs['units'][self.sigma_label_prefix + variable_name] = ureg.Unit('') - - def remove_variable(self, variable_name: str): - """ - Remove a DataArray from the DataSet by supplied name. - - :param variable_name: Name of DataArray to be removed - :type variable_name: str - :return: None - :rtype: None - """ - del self._obj[variable_name] - - def sigma_generator( - self, - variable_label: str, - sigma_func: Callable = lambda x: np.sqrt(np.abs(x)), - label_prefix: str = None, - ): - """ - Generate sigmas off of a DataArray based on a function. - - :param variable_label: Name of the DataArray to perform the calculation on - :type variable_label: str - :param sigma_func: Function to generate the sigmas. Must be of the form f(x) and return an array of the same shape as the input. 
Default sqrt(\\|x\\|) - :type sigma_func: Callable - :param label_prefix: What prefix should be used to designate a sigma DataArray from a data DataArray - :type label_prefix: str - :return: None - :rtype: None - """ # noqa: E501 - sigma_values = sigma_func(self._obj[variable_label]) - self.sigma_attach(variable_label, sigma_values, label_prefix) - - def sigma_attach( - self, - variable_label: str, - sigma_values: Union[List[T_], np.ndarray, xr.DataArray], - label_prefix: str = None, - ): - """ - Attach an array of sigmas to the DataSet. - - :param variable_label: Name of the DataArray to perform the calculation on - :type variable_label: str - :param sigma_values: Array of sigmas in list, numpy or DataArray form - :type sigma_values: Union[List[T_], numpy.ndarray, xarray.DataArray] - :param label_prefix: What prefix should be used to designate a sigma DataArray from a data DataArray - :type label_prefix: str - :return: None - :rtype: None - """ - # Use the default sigma prefix if not defined. - if label_prefix is None: - label_prefix = self.sigma_label_prefix - - # Form the label for the new DataArray - sigma_label = label_prefix + variable_label - - # Map the original DataArray to the new sigma DataArray - self.__error_mapper[variable_label] = sigma_label - # Assign the sigma DataArray to the DataSet - if not isinstance(sigma_values, xr.DataArray): - self._obj[sigma_label] = ( - list(self._obj[variable_label].coords.keys()), - sigma_values, - ) - else: - self._obj[sigma_label] = sigma_values - - def generate_points(self, coordinates: List[str]) -> xr.DataArray: - """ - Generate an expanded DataArray of points which corresponds to broadcasted dimensions (`all_x`) which have been - concatenated along the second axis (`fit_dim`). - - :param coordinates: List of coordinate names to broadcast and concatenate along - :type coordinates: List[str] - :return: Broadcasted and concatenated coordinates - :rtype: xarray.DataArray - - .. 
code-block:: python - - x = [1, 2], y = [3, 4] - d = xr.DataArray() - d.EasyScience.add_coordinate('x', x) - d.EasyScience.add_coordinate('y', y) - points = d.EasyScience.generate_points(['x', 'y']) - print(points) - """ - - coords = [self._obj.coords[da] for da in coordinates] - c_array = [] - n_array = [] - for da in xr.broadcast(*coords): - c_array.append(da) - n_array.append(da.name) - - f = xr.concat(c_array, dim='fit_dim') - f = f.stack(all_x=n_array) - return f - - def fit( - self, - fitter, - data_arrays: list, - *args, - dask: str = 'forbidden', - fit_kwargs: dict = None, - fn_kwargs: dict = None, - vectorized: bool = False, - **kwargs, - ) -> List[FitResults]: - """ - Perform a fit on one or more DataArrays. This fit utilises a given fitter from `EasyScience.fitting.Fitter`, though - there are a few differences to a standard EasyScience fit. In particular, key-word arguments to control the - optimisation algorithm go in the `fit_kwargs` dictionary, fit function key-word arguments go in the `fn_kwargs` - and given key-word arguments control the `xarray.apply_ufunc` function. - - :param fitter: Fitting object which controls the fitting - :type fitter: EasyScience.fitting.Fitter - :param args: Arguments to go to the fit function - :type args: Any - :param dask: Dask control string. See `xarray.apply_ufunc` documentation - :type dask: str - :param fit_kwargs: Dictionary of key-word arguments to be supplied to the Fitting control - :type fit_kwargs: dict - :param fn_kwargs: Dictionary of key-words to be supplied to the fit function - :type fn_kwargs: dict - :param vectorized: Should the fit function be given dependents in a single object or split - :type vectorized: bool - :param kwargs: Key-word arguments for `xarray.apply_ufunc`. 
See `xarray.apply_ufunc` documentation - :type kwargs: Any - :return: Results of the fit - :rtype: List[FitResults] - """ - - if fn_kwargs is None: - fn_kwargs = {} - if fit_kwargs is None: - fit_kwargs = {} - if not isinstance(data_arrays, (list, tuple)): - data_arrays = [data_arrays] - - # In this case we are only fitting 1 dataset - if len(data_arrays) == 1: - variable_label = data_arrays[0] - dataset = self._obj[variable_label] - if self.__error_mapper.get(variable_label, False): - # Pull out any sigmas and send them to the fitter. - temp = self._obj[self.__error_mapper[variable_label]] - temp[xr.ufuncs.isnan(temp)] = 1e5 - fit_kwargs['weights'] = temp - # Perform a standard DataArray fit. - return dataset.EasyScience.fit( - fitter, - *args, - fit_kwargs=fit_kwargs, - fn_kwargs=fn_kwargs, - dask=dask, - vectorize=vectorized, - **kwargs, - ) - else: - # In this case we are fitting multiple datasets to the same fn! - bdim_f = [self._obj[p].EasyScience.fit_prep(fitter.fit_function) for p in data_arrays] - dim_names = [ - list(self._obj[p].dims.keys()) if isinstance(self._obj[p].dims, dict) else self._obj[p].dims - for p in data_arrays - ] - bdims = [bdim[0] for bdim in bdim_f] - fs = [bdim[1] for bdim in bdim_f] - old_fit_func = fitter.fit_function - - fn_array = [] - y_list = [] - for _idx, d in enumerate(bdims): - dims = self._obj[data_arrays[_idx]].dims - if isinstance(dims, dict): - dims = list(dims.keys()) - - def local_fit_func(x, *args, idx=None, **kwargs): - kwargs['vectorize'] = vectorized - res = xr.apply_ufunc( - fs[idx], - *bdims[idx], - *args, - dask=dask, - kwargs=fn_kwargs, - **kwargs, - ) - if dask != 'forbidden': - res.compute() - return res.stack(all_x=dim_names[idx]) - - y_list.append(self._obj[data_arrays[_idx]].stack(all_x=dims)) - fn_array.append(local_fit_func) - - def fit_func(x, *args, **kwargs): - res = [] - for idx in range(len(fn_array)): - res.append(fn_array[idx](x, *args, idx=idx, **kwargs)) - return xr.DataArray(np.concatenate(res, 
axis=0), coords={'all_x': x}, dims='all_x') - - fitter.initialize(fitter.fit_object, fit_func) - try: - if fit_kwargs.get('weights', None) is not None: - del fit_kwargs['weights'] - x = xr.DataArray(np.arange(np.sum([y.size for y in y_list])), dims='all_x') - y = xr.DataArray(np.concatenate(y_list, axis=0), coords={'all_x': x}, dims='all_x') - f_res = fitter.fit(x, y, **fit_kwargs) - f_res = check_sanity_multiple(f_res, [self._obj[p] for p in data_arrays]) - finally: - fitter.fit_function = old_fit_func - return f_res - - -@xr.register_dataarray_accessor('EasyScience') -class EasyScienceDataarrayAccessor: - """ - Accessor to extend an xarray DataArray to EasyScience. These functions can be accessed by `obj.EasyScience.func`. - - """ - - def __init__(self, xarray_obj: xr.DataArray): - self._obj = xarray_obj - self._core_object = None - self.sigma_label_prefix = 's_' - if self._obj.attrs.get('computation', None) is None: - self._obj.attrs['computation'] = { - 'precompute_func': None, - 'compute_func': None, - 'postcompute_func': None, - } - - def __empty_functional(self) -> Callable: - def outer(): - def empty_fn(input, *args, **kwargs): - return input - - return empty_fn - - class wrapper: - def __init__(obj): - obj.obj = self - obj.data = {} - obj.fn = outer() - - def __call__(self, *args, **kwargs): - return self.fn(*args, **kwargs) - - return wrapper() - - @property - def core_object(self): - """ - Get the core object associated to a DataArray. Note that this is called from a weakref. If the EasyScience obj is - garbage collected, None will be returned. 
- - :return: EasyScience object associated with the DataArray - :rtype: Any - """ - if self._core_object is None: - return None - return self._core_object() - - @core_object.setter - def core_object(self, new_core_object: Any): - """ - Set the core object associated to a dataset - - :param new_core_object: EasyScience object to be associated with the DataArray - :type new_core_object: Any - :return: None - :rtype: None - """ - self._core_object = weakref.ref(new_core_object) - - @property - def compute_func(self) -> Callable: - """ - Get the computational function which will be executed during a fit - - :return: Computational function applied to the DataArray - :rtype: Callable - """ - result = self._obj.attrs['computation']['compute_func'] - if result is None: - result = self.__empty_functional() - return result - - @compute_func.setter - def compute_func(self, new_computational_fn: Callable): - """ - Set the computational function which is called during a fit - - :param new_computational_fn: Computational function applied to the DataArray - :type new_computational_fn: Callable - :return: None - :rtype: None - """ - self._obj.attrs['computation']['compute_func'] = new_computational_fn - - @property - def precompute_func(self) -> Callable: - """ - Get the pre-computational function which will be executed before a fit - - :return: Computational function applied to the DataArray before fitting - :rtype: Callable - """ - result = self._obj.attrs['computation']['precompute_func'] - if result is None: - result = self.__empty_functional() - return result - - @precompute_func.setter - def precompute_func(self, new_computational_fn: Callable): - """ - Set the computational function which is called before a fit - - :param new_computational_fn: Computational function applied to the DataArray before fitting - :type new_computational_fn: Callable - :return: None - :rtype: None - """ - self._obj.attrs['computation']['precompute_func'] = new_computational_fn - - @property - def 
postcompute_func(self) -> Callable: - """ - Get the post-computational function which will be executed after a fit - - :return: Computational function applied to the DataArray after fitting - :rtype: Callable - """ - result = self._obj.attrs['computation']['postcompute_func'] - if result is None: - result = self.__empty_functional() - return result - - @postcompute_func.setter - def postcompute_func(self, new_computational_fn: Callable): - """ - Set the computational function which is called after a fit - - :param new_computational_fn: Computational function applied to the DataArray after fitting - :type new_computational_fn: Callable - :return: None - :rtype: None - """ - self._obj.attrs['computation']['postcompute_func'] = new_computational_fn - - def fit_prep(self, func_in: Callable, bdims=None, dask_chunks=None) -> Tuple[xr.DataArray, Callable]: - """ - Generate broadcasted coordinates for fitting and reform the fitting function into one which can handle xarrays. - - :param func_in: Function to be wrapped and made xarray fitting compatible. - :type func_in: Callable - :param bdims: Optional precomputed broadcasted dimensions. - :type bdims: xarray.DataArray - :param dask_chunks: How to split the broadcasted dimensions for dask. - :type dask_chunks: Tuple[int..] - :return: Tuple of broadcasted fit arrays and wrapped fit function. 
- :rtype: xarray.DataArray, Callable - """ - - if bdims is None: - coords = [self._obj.coords[da].transpose() for da in self._obj.dims] - bdims = xr.broadcast(*coords) - self._obj.attrs['computation']['compute_func'] = func_in - - def func(x, *args, vectorize: bool = False, **kwargs): - old_shape = x.shape - if not vectorize: - xs = [x_new.flatten() for x_new in [x, *args] if isinstance(x_new, np.ndarray)] - x_new = np.column_stack(xs) - if len(x_new.shape) > 1 and x_new.shape[1] == 1: - x_new = x_new.reshape((-1)) - result = self.compute_func(x_new, **kwargs) - else: - result = self.compute_func( - *[d for d in [x, args] if isinstance(d, np.ndarray)], - *[d for d in args if not isinstance(d, np.ndarray)], - **kwargs, - ) - if isinstance(result, np.ndarray): - result = result.reshape(old_shape) - result = self.postcompute_func(result) - return result - - return bdims, func - - def generate_points(self) -> xr.DataArray: - """ - Generate an expanded DataArray of points which corresponds to broadcasted dimensions (`all_x`) which have been - concatenated along the second axis (`fit_dim`). - - :return: Broadcasted and concatenated coordinates - :rtype: xarray.DataArray - """ - - coords = [self._obj.coords[da] for da in self._obj.dims] - c_array = [] - n_array = [] - for da in xr.broadcast(*coords): - c_array.append(da) - n_array.append(da.name) - - f = xr.concat(c_array, dim='fit_dim') - f = f.stack(all_x=n_array) - return f - - def fit( - self, - fitter, - *args, - fit_kwargs: dict = None, - fn_kwargs: dict = None, - vectorize: bool = False, - dask: str = 'forbidden', - **kwargs, - ) -> FitResults: - """ - Perform a fit on the given DataArray. This fit utilises a given fitter from `EasyScience.fitting.Fitter`, though - there are a few differences to a standard EasyScience fit. 
In particular, key-word arguments to control the - optimisation algorithm go in the `fit_kwargs` dictionary, fit function key-word arguments go in the `fn_kwargs` - and given key-word arguments control the `xarray.apply_ufunc` function. - - :param fitter: Fitting object which controls the fitting - :type fitter: EasyScience.fitting.Fitter - :param args: Arguments to go to the fit function - :type args: Any - :param dask: Dask control string. See `xarray.apply_ufunc` documentation - :type dask: str - :param fit_kwargs: Dictionary of key-word arguments to be supplied to the Fitting control - :type fit_kwargs: dict - :param fn_kwargs: Dictionary of key-words to be supplied to the fit function - :type fn_kwargs: dict - :param vectorize: Should the fit function be given dependents in a single object or split - :type vectorize: bool - :param kwargs: Key-word arguments for `xarray.apply_ufunc`. See `xarray.apply_ufunc` documentation - :type kwargs: Any - :return: Results of the fit - :rtype: FitResults - """ - - # Deal with any kwargs which has been given - if fn_kwargs is None: - fn_kwargs = {} - if fit_kwargs is None: - fit_kwargs = {} - old_fit_func = fitter.fit_function - - # Wrap and broadcast - bdims, f = self.fit_prep(fitter.fit_function) - dims = self._obj.dims - - # Find which coords we need - if isinstance(dims, dict): - dims = list(dims.keys()) - - # Wrap the wrap in a callable - def local_fit_func(x, *args, **kwargs): - """ - Function which will be called by the fitter. This will deal with sending the function the correct data. 
- """ - kwargs['vectorize'] = vectorize - res = xr.apply_ufunc(f, *bdims, *args, dask=dask, kwargs=fn_kwargs, **kwargs) - if dask != 'forbidden': - res.compute() - return res.stack(all_x=dims) - - # Set the new callable to the fitter and initialize - fitter.initialize(fitter.fit_object, local_fit_func) - # Make EasyScience.fitting.Fitter compatible `x` - x_for_fit = xr.concat(bdims, dim='fit_dim') - x_for_fit = x_for_fit.stack(all_x=[d.name for d in bdims]) - try: - # Deal with any sigmas if supplied - if fit_kwargs.get('weights', None) is not None: - fit_kwargs['weights'] = xr.DataArray( - np.array(fit_kwargs['weights']), - dims=['all_x'], - coords={'all_x': x_for_fit.all_x}, - ) - # Try to perform a fit - f_res = fitter.fit(x_for_fit, self._obj.stack(all_x=dims), **fit_kwargs) - f_res = check_sanity_single(f_res) - finally: - # Reset the fit function on the fitter to the old fit function. - fitter.fit_function = old_fit_func - return f_res - - -def check_sanity_single(fit_results: FitResults) -> FitResults: - """ - Convert the FitResults from a fitter compatible state to a recognizable DataArray state. 
- - :param fit_results: Results of a fit to be modified - :type fit_results: FitResults - :return: Modified fit results - :rtype: FitResults - """ - items = ['y_obs', 'y_calc', 'residual'] - - for item in items: - array = getattr(fit_results, item) - if isinstance(array, xr.DataArray): - array = array.unstack() - array.name = item - setattr(fit_results, item, array) - - x_array = fit_results.x - if isinstance(x_array, xr.DataArray): - fit_results.x.name = 'axes_broadcast' - x_array = x_array.unstack() - x_dataset = xr.Dataset() - dims = [dims for dims in x_array.dims if dims != 'fit_dim'] - for idx, dim in enumerate(dims): - x_dataset[dim + '_broadcast'] = x_array[idx] - x_dataset[dim + '_broadcast'].name = dim + '_broadcast' - fit_results.x_matrices = x_dataset - else: - fit_results.x_matrices = x_array - return fit_results - - -def check_sanity_multiple(fit_results: FitResults, originals: List[xr.DataArray]) -> List[FitResults]: - """ - Convert the multifit FitResults from a fitter compatible state to a list of recognizable DataArray states. - - :param fit_results: Results of a fit to be modified - :type fit_results: FitResults - :param originals: List of DataArrays which were fitted against, so we can resize and re-chunk the results - :type originals: List[xr.DataArray] - :return: Modified fit results - :rtype: List[FitResults] - """ - - return_results = [] - offset = 0 - for item in originals: - current_results = fit_results.__class__() - # Fill out the basic stuff.... 
- current_results.engine_result = fit_results.engine_result - current_results.minimizer_engine = fit_results.minimizer_engine - current_results.success = fit_results.success - current_results.p = fit_results.p - current_results.p0 = fit_results.p0 - # now the tricky stuff - current_results.x = item.EasyScience.generate_points() - current_results.y_obs = item.copy() - current_results.y_obs.name = f'{item.name}_obs' - current_results.y_calc = xr.DataArray( - fit_results.y_calc[offset : offset + item.size].data, - dims=item.dims, - coords=item.coords, - name=f'{item.name}_calc', - ) - offset += item.size - current_results.residual = current_results.y_calc - current_results.y_obs - current_results.residual.name = f'{item.name}_residual' - return_results.append(current_results) - return return_results diff --git a/src/easyscience/Objects/ObjectClasses.py b/src/easyscience/Objects/ObjectClasses.py deleted file mode 100644 index e6a159ad..00000000 --- a/src/easyscience/Objects/ObjectClasses.py +++ /dev/null @@ -1,367 +0,0 @@ -from __future__ import annotations - -__author__ = 'github.com/wardsimon' -__version__ = '0.1.0' - -# SPDX-FileCopyrightText: 2023 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project Set[str]: - base_cls = getattr(self, '__old_class__', self.__class__) - spec = getfullargspec(base_cls.__init__) - names = set(spec.args[1:]) - return names - - def __reduce__(self): - """ - Make the class picklable. - Due to the nature of the dynamic class definitions special measures need to be taken. 
- - :return: Tuple consisting of how to make the object - :rtype: tuple - """ - state = self.encode() - cls = getattr(self, '__old_class__', self.__class__) - return cls.from_dict, (state,) - - @property - def unique_name(self) -> str: - """Get the unique name of the object.""" - return self._unique_name - - @unique_name.setter - def unique_name(self, new_unique_name: str): - """Set a new unique name for the object. The old name is still kept in the map. - - :param new_unique_name: New unique name for the object""" - if not isinstance(new_unique_name, str): - raise TypeError('Unique name has to be a string.') - self._unique_name = new_unique_name - self._global_object.map.add_vertex(self) - - @property - def name(self) -> str: - """ - Get the common name of the object. - - :return: Common name of the object - """ - return self._name - - @name.setter - def name(self, new_name: str): - """ - Set a new common name for the object. - - :param new_name: New name for the object - :return: None - """ - self._name = new_name - - @property - def interface(self) -> iF: - """ - Get the current interface of the object - """ - return self._interface - - @interface.setter - def interface(self, new_interface: iF): - """ - Set the current interface to the object and generate bindings if possible. iF.e. - ``` - def __init__(self, bar, interface=None, **kwargs): - super().__init__(self, **kwargs) - self.foo = bar - self.interface = interface # As final step after initialization to set correct bindings. - ``` - """ - self._interface = new_interface - if new_interface is not None: - self.generate_bindings() - - def generate_bindings(self): - """ - Generate or re-generate bindings to an interface (if exists) - - :raises: AttributeError - """ - if self.interface is None: - raise AttributeError('Interface error for generating bindings. 
`interface` has to be set.') - interfaceable_children = [ - key - for key in self._global_object.map.get_edges(self) - if issubclass(type(self._global_object.map.get_item_by_key(key)), BasedBase) - ] - for child_key in interfaceable_children: - child = self._global_object.map.get_item_by_key(child_key) - child.interface = self.interface - self.interface.generate_bindings(self) - - def switch_interface(self, new_interface_name: str): - """ - Switch or create a new interface. - """ - if self.interface is None: - raise AttributeError('Interface error for generating bindings. `interface` has to be set.') - self.interface.switch(new_interface_name) - self.generate_bindings() - - @property - def constraints(self) -> List[C]: - pars = self.get_parameters() - constraints = [] - for par in pars: - con: Dict[str, C] = par.user_constraints - for key in con.keys(): - constraints.append(con[key]) - return constraints - - def get_parameters(self) -> List[Parameter]: - """ - Get all parameter objects as a list. - - :return: List of `Parameter` objects. - """ - par_list = [] - for key, item in self._kwargs.items(): - if hasattr(item, 'get_parameters'): - par_list = [*par_list, *item.get_parameters()] - elif isinstance(item, Parameter): - par_list.append(item) - return par_list - - def _get_linkable_attributes(self) -> List[V]: - """ - Get all objects which can be linked against as a list. - - :return: List of `Descriptor`/`Parameter` objects. - """ - item_list = [] - for key, item in self._kwargs.items(): - if hasattr(item, '_get_linkable_attributes'): - item_list = [*item_list, *item._get_linkable_attributes()] - elif issubclass(type(item), (DescriptorBase)): - item_list.append(item) - return item_list - - def get_fit_parameters(self) -> List[Parameter]: - """ - Get all objects which can be fitted (and are not fixed) as a list. - - :return: List of `Parameter` objects which can be used in fitting. 
- """ - fit_list = [] - for key, item in self._kwargs.items(): - if hasattr(item, 'get_fit_parameters'): - fit_list = [*fit_list, *item.get_fit_parameters()] - elif isinstance(item, Parameter): - if item.enabled and not item.fixed: - fit_list.append(item) - return fit_list - - def __dir__(self) -> Iterable[str]: - """ - This creates auto-completion and helps out in iPython notebooks. - - :return: list of function and parameter names for auto-completion - """ - new_class_objs = list(k for k in dir(self.__class__) if not k.startswith('_')) - return sorted(new_class_objs) - - def __copy__(self) -> BasedBase: - """Return a copy of the object.""" - temp = self.as_dict(skip=['unique_name']) - new_obj = self.__class__.from_dict(temp) - return new_obj - - -if TYPE_CHECKING: - B = TypeVar('B', bound=BasedBase) - BV = TypeVar('BV', bound=ComponentSerializer) - - -class BaseObj(BasedBase): - """ - This is the base class for which all higher level classes are built off of. - NOTE: This object is serializable only if parameters are supplied as: - `BaseObj(a=value, b=value)`. For `Parameter` or `Descriptor` objects we can - cheat with `BaseObj(*[Descriptor(...), Parameter(...), ...])`. - """ - - def __init__( - self, - name: str, - unique_name: Optional[str] = None, - *args: Optional[BV], - **kwargs: Optional[BV], - ): - """ - Set up the base class. - - :param name: Name of this object - :param args: Any arguments? - :param kwargs: Fields which this class should contain - """ - super(BaseObj, self).__init__(name=name, unique_name=unique_name) - # If Parameter or Descriptor is given as arguments... 
- for arg in args: - if issubclass(type(arg), (BaseObj, DescriptorBase)): - kwargs[getattr(arg, 'name')] = arg - # Set kwargs, also useful for serialization - known_keys = self.__dict__.keys() - self._kwargs = kwargs - for key in kwargs.keys(): - if key in known_keys: - raise AttributeError('Kwargs cannot overwrite class attributes in BaseObj.') - if issubclass(type(kwargs[key]), (BasedBase, DescriptorBase)) or 'BaseCollection' in [ - c.__name__ for c in type(kwargs[key]).__bases__ - ]: - self._global_object.map.add_edge(self, kwargs[key]) - self._global_object.map.reset_type(kwargs[key], 'created_internal') - addLoggedProp( - self, - key, - self.__getter(key), - self.__setter(key), - get_id=key, - my_self=self, - test_class=BaseObj, - ) - - def _add_component(self, key: str, component: BV) -> None: - """ - Dynamically add a component to the class. This is an internal method, though can be called remotely. - The recommended alternative is to use typing, i.e. - - class Foo(Bar): - def __init__(self, foo: Parameter, bar: Parameter): - super(Foo, self).__init__(bar=bar) - self._add_component("foo", foo) - - Goes to: - class Foo(Bar): - foo: ClassVar[Parameter] - def __init__(self, foo: Parameter, bar: Parameter): - super(Foo, self).__init__(bar=bar) - self.foo = foo - - :param key: Name of component to be added - :param component: Component to be added - :return: None - """ - self._kwargs[key] = component - self._global_object.map.add_edge(self, component) - self._global_object.map.reset_type(component, 'created_internal') - addLoggedProp( - self, - key, - self.__getter(key), - self.__setter(key), - get_id=key, - my_self=self, - test_class=BaseObj, - ) - - def __setattr__(self, key: str, value: BV) -> None: - # Assume that the annotation is a ClassVar - old_obj = None - if ( - hasattr(self.__class__, '__annotations__') - and key in self.__class__.__annotations__ - and hasattr(self.__class__.__annotations__[key], '__args__') - and issubclass( - getattr(value, 
'__old_class__', value.__class__), - self.__class__.__annotations__[key].__args__, - ) - ): - if issubclass(type(getattr(self, key, None)), (BasedBase, DescriptorBase)): - old_obj = self.__getattribute__(key) - self._global_object.map.prune_vertex_from_edge(self, old_obj) - self._add_component(key, value) - else: - if hasattr(self, key) and issubclass(type(value), (BasedBase, DescriptorBase)): - old_obj = self.__getattribute__(key) - self._global_object.map.prune_vertex_from_edge(self, old_obj) - self._global_object.map.add_edge(self, value) - super(BaseObj, self).__setattr__(key, value) - # Update the interface bindings if something changed (BasedBase and Descriptor) - if old_obj is not None: - old_interface = getattr(self, 'interface', None) - if old_interface is not None: - self.generate_bindings() - - def __repr__(self) -> str: - return f"{self.__class__.__name__} `{getattr(self, 'name')}`" - - @staticmethod - def __getter(key: str) -> Callable[[BV], BV]: - def getter(obj: BV) -> BV: - return obj._kwargs[key] - - return getter - - @staticmethod - def __setter(key: str) -> Callable[[BV], None]: - def setter(obj: BV, value: float) -> None: - if issubclass(obj._kwargs[key].__class__, (DescriptorBase)) and not issubclass( - value.__class__, (DescriptorBase) - ): - obj._kwargs[key].value = value - else: - obj._kwargs[key] = value - - return setter - - # @staticmethod - # def __setter(key: str) -> Callable[[Union[B, V]], None]: - # def setter(obj: Union[V, B], value: float) -> None: - # if issubclass(obj._kwargs[key].__class__, Descriptor): - # if issubclass(obj._kwargs[key].__class__, Descriptor): - # obj._kwargs[key] = value - # else: - # obj._kwargs[key].value = value - # else: - # obj._kwargs[key] = value - # - # return setter diff --git a/src/easyscience/Objects/__init__.py b/src/easyscience/Objects/__init__.py deleted file mode 100644 index 22e236a6..00000000 --- a/src/easyscience/Objects/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: 2023 
EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project Set[str]: + base_cls = getattr(self, '__old_class__', self.__class__) + spec = getfullargspec(base_cls.__init__) + names = set(spec.args[1:]) + return names + + def __reduce__(self): + """ + Make the class picklable. + Due to the nature of the dynamic class definitions special measures need to be taken. + + :return: Tuple consisting of how to make the object + :rtype: tuple + """ + state = self.encode() + cls = getattr(self, '__old_class__', self.__class__) + return cls.from_dict, (state,) + + @property + def unique_name(self) -> str: + """Get the unique name of the object.""" + return self._unique_name + + @unique_name.setter + def unique_name(self, new_unique_name: str): + """Set a new unique name for the object. The old name is still kept in the map. + + :param new_unique_name: New unique name for the object""" + if not isinstance(new_unique_name, str): + raise TypeError('Unique name has to be a string.') + self._unique_name = new_unique_name + self._global_object.map.add_vertex(self) + + @property + def name(self) -> str: + """ + Get the common name of the object. + + :return: Common name of the object + """ + return self._name + + @name.setter + def name(self, new_name: str): + """ + Set a new common name for the object. 
+ + :param new_name: New name for the object + :return: None + """ + self._name = new_name + + @property + def interface(self) -> InterfaceFactoryTemplate: + """ + Get the current interface of the object + """ + return self._interface + + @interface.setter + def interface(self, new_interface: InterfaceFactoryTemplate): + """ + Set the current interface to the object and generate bindings if possible. iF.e. + ``` + def __init__(self, bar, interface=None, **kwargs): + super().__init__(self, **kwargs) + self.foo = bar + self.interface = interface # As final step after initialization to set correct bindings. + ``` + """ + self._interface = new_interface + if new_interface is not None: + self.generate_bindings() + + def generate_bindings(self): + """ + Generate or re-generate bindings to an interface (if exists) + + :raises: AttributeError + """ + if self.interface is None: + raise AttributeError('Interface error for generating bindings. `interface` has to be set.') + interfaceable_children = [ + key + for key in self._global_object.map.get_edges(self) + if issubclass(type(self._global_object.map.get_item_by_key(key)), BasedBase) + ] + for child_key in interfaceable_children: + child = self._global_object.map.get_item_by_key(child_key) + child.interface = self.interface + self.interface.generate_bindings(self) + + def switch_interface(self, new_interface_name: str): + """ + Switch or create a new interface. + """ + if self.interface is None: + raise AttributeError('Interface error for generating bindings. `interface` has to be set.') + self.interface.switch(new_interface_name) + self.generate_bindings() + + def get_parameters(self) -> List[Parameter]: + """ + Get all parameter objects as a list. + + :return: List of `Parameter` objects. 
+ """ + par_list = [] + for key, item in self._kwargs.items(): + if hasattr(item, 'get_parameters'): + par_list = [*par_list, *item.get_parameters()] + elif isinstance(item, Parameter): + par_list.append(item) + return par_list + + def _get_linkable_attributes(self) -> List[DescriptorBase]: + """ + Get all objects which can be linked against as a list. + + :return: List of `Descriptor`/`Parameter` objects. + """ + item_list = [] + for key, item in self._kwargs.items(): + if hasattr(item, '_get_linkable_attributes'): + item_list = [*item_list, *item._get_linkable_attributes()] + elif issubclass(type(item), (DescriptorBase)): + item_list.append(item) + return item_list + + def get_fit_parameters(self) -> List[Parameter]: + """ + Get all objects which can be fitted (and are not fixed) as a list. + + :return: List of `Parameter` objects which can be used in fitting. + """ + fit_list = [] + for key, item in self._kwargs.items(): + if hasattr(item, 'get_fit_parameters'): + fit_list = [*fit_list, *item.get_fit_parameters()] + elif isinstance(item, Parameter): + if item.independent and not item.fixed: + fit_list.append(item) + return fit_list + + def __dir__(self) -> Iterable[str]: + """ + This creates auto-completion and helps out in iPython notebooks. 
+ + :return: list of function and parameter names for auto-completion + """ + new_class_objs = list(k for k in dir(self.__class__) if not k.startswith('_')) + return sorted(new_class_objs) + + def __copy__(self) -> BasedBase: + """Return a copy of the object.""" + temp = self.as_dict(skip=['unique_name']) + new_obj = self.__class__.from_dict(temp) + return new_obj + + diff --git a/src/easyscience/Objects/Groups.py b/src/easyscience/base_classes/collection_base.py similarity index 88% rename from src/easyscience/Objects/Groups.py rename to src/easyscience/base_classes/collection_base.py index 90d4f0c6..45d6f39c 100644 --- a/src/easyscience/Objects/Groups.py +++ b/src/easyscience/base_classes/collection_base.py @@ -1,12 +1,9 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project None: + def insert(self, index: int, value: Union[DescriptorBase, BasedBase]) -> None: """ Insert an object into the collection at an index. @@ -122,14 +118,14 @@ def insert(self, index: int, value: Union[V, B]) -> None: else: raise AttributeError('Only EasyScience objects can be put into an EasyScience group') - def __getitem__(self, idx: Union[int, slice]) -> Union[V, B]: + def __getitem__(self, idx: Union[int, slice]) -> Union[DescriptorBase, BasedBase]: """ Get an item in the collection based on its index. :param idx: index or slice of the collection. 
:type idx: Union[int, slice] :return: Object at index `idx` - :rtype: Union[Parameter, Descriptor, BaseObj, 'BaseCollection'] + :rtype: Union[Parameter, Descriptor, ObjBase, 'CollectionBase'] """ if isinstance(idx, slice): start, stop, step = idx.indices(len(self)) @@ -156,7 +152,7 @@ def __getitem__(self, idx: Union[int, slice]) -> Union[V, B]: keys = list(self._kwargs.keys()) return self._kwargs[keys[idx]] - def __setitem__(self, key: int, value: Union[B, V]) -> None: + def __setitem__(self, key: int, value: Union[BasedBase, DescriptorBase]) -> None: """ Set an item via it's index. @@ -238,7 +234,7 @@ def data(self) -> Tuple: def __repr__(self) -> str: return f"{self.__class__.__name__} `{getattr(self, 'name')}` of length {len(self)}" - def sort(self, mapping: Callable[[Union[B, V]], Any], reverse: bool = False) -> None: + def sort(self, mapping: Callable[[Union[BasedBase, DescriptorBase]], Any], reverse: bool = False) -> None: """ Sort the collection according to the given mapping. diff --git a/src/easyscience/base_classes/obj_base.py b/src/easyscience/base_classes/obj_base.py new file mode 100644 index 00000000..33316259 --- /dev/null +++ b/src/easyscience/base_classes/obj_base.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project None: + """ + Dynamically add a component to the class. This is an internal method, though can be called remotely. + The recommended alternative is to use typing, i.e. 
+ + class Foo(Bar): + def __init__(self, foo: Parameter, bar: Parameter): + super(Foo, self).__init__(bar=bar) + self._add_component("foo", foo) + + Goes to: + class Foo(Bar): + foo: ClassVar[Parameter] + def __init__(self, foo: Parameter, bar: Parameter): + super(Foo, self).__init__(bar=bar) + self.foo = foo + + :param key: Name of component to be added + :param component: Component to be added + :return: None + """ + self._kwargs[key] = component + self._global_object.map.add_edge(self, component) + self._global_object.map.reset_type(component, 'created_internal') + addLoggedProp( + self, + key, + self.__getter(key), + self.__setter(key), + get_id=key, + my_self=self, + test_class=ObjBase, + ) + + def __setattr__(self, key: str, value: SerializerComponent) -> None: + # Assume that the annotation is a ClassVar + old_obj = None + if ( + hasattr(self.__class__, '__annotations__') + and key in self.__class__.__annotations__ + and hasattr(self.__class__.__annotations__[key], '__args__') + and issubclass( + getattr(value, '__old_class__', value.__class__), + self.__class__.__annotations__[key].__args__, + ) + ): + if issubclass(type(getattr(self, key, None)), (BasedBase, DescriptorBase)): + old_obj = self.__getattribute__(key) + self._global_object.map.prune_vertex_from_edge(self, old_obj) + self._add_component(key, value) + else: + if hasattr(self, key) and issubclass(type(value), (BasedBase, DescriptorBase)): + old_obj = self.__getattribute__(key) + self._global_object.map.prune_vertex_from_edge(self, old_obj) + self._global_object.map.add_edge(self, value) + super(ObjBase, self).__setattr__(key, value) + # Update the interface bindings if something changed (BasedBase and Descriptor) + if old_obj is not None: + old_interface = getattr(self, 'interface', None) + if old_interface is not None: + self.generate_bindings() + + def __repr__(self) -> str: + return f"{self.__class__.__name__} `{getattr(self, 'name')}`" + + @staticmethod + def __getter(key: str) -> 
Callable[[SerializerComponent], SerializerComponent]: + def getter(obj: SerializerComponent) -> SerializerComponent: + return obj._kwargs[key] + + return getter + + @staticmethod + def __setter(key: str) -> Callable[[SerializerComponent], None]: + def setter(obj: SerializerComponent, value: float) -> None: + if issubclass(obj._kwargs[key].__class__, (DescriptorBase)) and not issubclass( + value.__class__, (DescriptorBase) + ): + obj._kwargs[key].value = value + else: + obj._kwargs[key] = value + + return setter + + # @staticmethod + # def __setter(key: str) -> Callable[[Union[B, V]], None]: + # def setter(obj: Union[V, B], value: float) -> None: + # if issubclass(obj._kwargs[key].__class__, Descriptor): + # if issubclass(obj._kwargs[key].__class__, Descriptor): + # obj._kwargs[key] = value + # else: + # obj._kwargs[key].value = value + # else: + # obj._kwargs[key] = value + # + # return setter diff --git a/src/easyscience/fitting/calculators/__init__.py b/src/easyscience/fitting/calculators/__init__.py new file mode 100644 index 00000000..a3ca5d43 --- /dev/null +++ b/src/easyscience/fitting/calculators/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project List[str]: return [self.return_name(this_interface) for this_interface in self._interfaces] @property - def current_interface(self) -> _C: + def current_interface(self) -> ABCMeta: """ Returns the constructor for the currently selected interface. 
@@ -174,7 +168,7 @@ def generate_bindings(self, model, *args, ifun=None, **kwargs): prop._callback = item.make_prop(item_key) prop._callback.fset(prop_value) - def __call__(self, *args, **kwargs) -> _M: + def __call__(self, *args, **kwargs) -> None: return self.__interface_obj def __reduce__(self): @@ -233,6 +227,3 @@ def set_value(value): self.setter_fn(self.link_name, **{inner_key: value}) return set_value - - -iF = TypeVar('iF', bound=InterfaceFactoryTemplate) diff --git a/src/easyscience/fitting/fitter.py b/src/easyscience/fitting/fitter.py index daea7782..0cb67016 100644 --- a/src/easyscience/fitting/fitter.py +++ b/src/easyscience/fitting/fitter.py @@ -1,6 +1,6 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project list: - return self._minimizer.fit_constraints() - - def add_fit_constraint(self, constraint) -> None: - self._minimizer.add_fit_constraint(constraint) - - def remove_fit_constraint(self, index: int) -> None: - self._minimizer.remove_fit_constraint(index) - def make_model(self, pars=None) -> Callable: return self._minimizer.make_model(pars) @@ -84,9 +75,7 @@ def switch_minimizer(self, minimizer_enum: Union[AvailableMinimizers, str]) -> N print(f'minimizer should be set with enum {minimizer_enum}') minimizer_enum = from_string_to_enum(minimizer_enum) - constraints = self._minimizer.fit_constraints() self._update_minimizer(minimizer_enum) - self._minimizer.set_fit_constraint(constraints) def _update_minimizer(self, minimizer_enum: AvailableMinimizers) -> None: self._minimizer = factory(minimizer_enum=minimizer_enum, fit_object=self._fit_object, fit_function=self.fit_function) @@ -235,11 +224,7 @@ def inner_fit_callable( # Fit fit_fun_org = self._fit_function fit_fun_wrap = self._fit_function_wrapper(x_new, flatten=True) # This should be wrapped. 
- - # We change the fit function, so have to reset constraints - constraints = self._minimizer.fit_constraints() self.fit_function = fit_fun_wrap - self._minimizer.set_fit_constraint(constraints) f_res = self._minimizer.fit( x_fit, y_new, @@ -251,9 +236,8 @@ def inner_fit_callable( # Postcompute fit_result = self._post_compute_reshaping(f_res, x, y) - # Reset the function and constrains + # Reset the function self.fit_function = fit_fun_org - self._minimizer.set_fit_constraint(constraints) return fit_result return inner_fit_callable diff --git a/src/easyscience/fitting/minimizers/__init__.py b/src/easyscience/fitting/minimizers/__init__.py index eecbfc8a..b4de4c38 100644 --- a/src/easyscience/fitting/minimizers/__init__.py +++ b/src/easyscience/fitting/minimizers/__init__.py @@ -1,6 +1,6 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project List[ObjConstraint]: - return [*self._constraints, *self._object._constraints] @property def enum(self) -> AvailableMinimizers: @@ -66,18 +59,6 @@ def enum(self) -> AvailableMinimizers: def name(self) -> str: return self._minimizer_enum.name - def fit_constraints(self) -> List[ObjConstraint]: - return self._constraints - - def set_fit_constraint(self, constraints: List[ObjConstraint]): - self._constraints = constraints - - def add_fit_constraint(self, constraint: ObjConstraint): - self._constraints.append(constraint) - - def remove_fit_constraint(self, index: int) -> None: - del self._constraints[index] - @abstractmethod def fit( self, @@ -178,9 +159,9 @@ def all_methods() -> List[str]: @staticmethod @abstractmethod - def convert_to_par_object(obj): # todo after constraint changes, add type hint: obj: BaseObj + def 
convert_to_par_object(obj): # todo after constraint changes, add type hint: obj: ObjBase """ - Convert an `EasyScience.Objects.Base.Parameter` object to an engine Parameter object. + Convert an `EasyScience.variable.Parameter` object to an engine Parameter object. """ def _prepare_parameters(self, parameters: dict[str, float]) -> dict[str, float]: @@ -237,8 +218,6 @@ def _fit_function(x: np.ndarray, **kwargs): # Since we are calling the parameter fset will be called. # TODO Pre processing here - for constraint in self.fit_constraints(): - constraint() return_data = func(x) # TODO Loading or manipulating data here return return_data diff --git a/src/easyscience/fitting/minimizers/minimizer_bumps.py b/src/easyscience/fitting/minimizers/minimizer_bumps.py index 14df1d0f..ed2b140b 100644 --- a/src/easyscience/fitting/minimizers/minimizer_bumps.py +++ b/src/easyscience/fitting/minimizers/minimizer_bumps.py @@ -1,6 +1,6 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project List[BumpsPara :rtype: List[BumpsParameter] """ if par_list is None: - # Assume that we have a BaseObj for which we can obtain a list + # Assume that we have a ObjBase for which we can obtain a list par_list = self._object.get_fit_parameters() pars_obj = [self.__class__.convert_to_par_object(obj) for obj in par_list] return pars_obj @@ -160,7 +160,7 @@ def convert_to_pars_obj(self, par_list: Optional[List] = None) -> List[BumpsPara @staticmethod def convert_to_par_object(obj) -> BumpsParameter: """ - Convert an `EasyScience.Objects.Base.Parameter` object to a bumps Parameter object + Convert an `EasyScience.variable.Parameter` object to a bumps Parameter object :return: bumps Parameter compatible object. 
:rtype: BumpsParameter diff --git a/src/easyscience/fitting/minimizers/minimizer_dfo.py b/src/easyscience/fitting/minimizers/minimizer_dfo.py index 27f7eba4..bcc5afea 100644 --- a/src/easyscience/fitting/minimizers/minimizer_dfo.py +++ b/src/easyscience/fitting/minimizers/minimizer_dfo.py @@ -1,6 +1,6 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project L :return: lmfit Parameters compatible object """ if parameters is None: - # Assume that we have a BaseObj for which we can obtain a list + # Assume that we have a ObjBase for which we can obtain a list parameters = self._object.get_fit_parameters() lm_parameters = LMParameters().add_many([self.convert_to_par_object(parameter) for parameter in parameters]) return lm_parameters @@ -175,7 +175,7 @@ def convert_to_pars_obj(self, parameters: Optional[List[Parameter]] = None) -> L @staticmethod def convert_to_par_object(parameter: Parameter) -> LMParameter: """ - Convert an `EasyScience.Objects.Base.Parameter` object to a lmfit Parameter object. + Convert an EasyScience Parameter object to a lmfit Parameter object. :return: lmfit Parameter compatible object. 
:rtype: LMParameter diff --git a/src/easyscience/fitting/multi_fitter.py b/src/easyscience/fitting/multi_fitter.py index c812ff0e..a30bdcec 100644 --- a/src/easyscience/fitting/multi_fitter.py +++ b/src/easyscience/fitting/multi_fitter.py @@ -1,14 +1,13 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project NoReturn: self._parent.data = self._new_value -def property_stack_deco(arg: Union[str, Callable], begin_macro: bool = False) -> Callable: +def property_stack(arg: Union[str, Callable], begin_macro: bool = False) -> Callable: """ Decorate a `property` setter with undo/redo functionality This decorator can be used as: - @property_stack_deco + @property_stack def func() .... 
or - @property_stack_deco("This is the undo/redo text) + @property_stack("This is the undo/redo text") def func() .... diff --git a/src/easyscience/io/__init__.py b/src/easyscience/io/__init__.py new file mode 100644 index 00000000..b21e648c --- /dev/null +++ b/src/easyscience/io/__init__.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project any: + def encode(self, obj: SerializerComponent, skip: Optional[List[str]] = None, **kwargs) -> any: """ Abstract implementation of an encoder. @@ -56,7 +51,7 @@ def encode(self, obj: BV, skip: Optional[List[str]] = None, **kwargs) -> any: @abstractmethod def decode(cls, obj: Any) -> Any: """ - Re-create an EasyScience object from the output of an encoder. The default decoder is `DictSerializer`. + Re-create an EasyScience object from the output of an encoder. The default decoder is `SerializerDict`. 
:param obj: encoded EasyScience object :return: Reformed EasyScience object @@ -83,7 +78,7 @@ def _encode_objs(obj: Any) -> Dict[str, Any]: :param obj: any object to be encoded :param skip: List of field names as strings to skip when forming the encoded object - :param kwargs: Key-words to pass to `BaseEncoderDecoder` + :param kwargs: Key-words to pass to `SerializerBase` :return: JSON encoded dictionary """ @@ -117,7 +112,7 @@ def _encode_objs(obj: Any) -> Dict[str, Any]: def _convert_to_dict( self, - obj: BV, + obj: SerializerComponent, skip: Optional[List[str]] = None, full_encode: bool = False, **kwargs, @@ -129,20 +124,20 @@ def _convert_to_dict( skip = [] if full_encode: - new_obj = BaseEncoderDecoder._encode_objs(obj) + new_obj = SerializerBase._encode_objs(obj) if new_obj is not obj: return new_obj - d = {'@module': get_class_module(obj), '@class': obj.__class__.__name__} + d = {'@module': obj.__module__, '@class': obj.__class__.__name__} try: - parent_module = get_class_module(obj).split('.')[0] + parent_module = obj.__module__.split('.')[0] module_version = import_module(parent_module).__version__ # type: ignore d['@version'] = '{}'.format(module_version) except (AttributeError, ImportError): d['@version'] = None # type: ignore - spec, args = BaseEncoderDecoder.get_arg_spec(obj.__class__.__init__) + spec, args = SerializerBase.get_arg_spec(obj.__class__.__init__) if hasattr(obj, '_arg_spec'): args = obj._arg_spec @@ -150,7 +145,7 @@ def _convert_to_dict( def runner(o): if full_encode: - return BaseEncoderDecoder._encode_objs(o) + return SerializerBase._encode_objs(o) else: return o @@ -194,7 +189,7 @@ def runner(o): 'determine the dict format. Alternatively, ' 'you can implement both as_dict and from_dict.' 
) - d[c] = recursive_encoder(a, skip=skip, encoder=self, full_encode=full_encode, **kwargs) + d[c] = self._recursive_encoder(a, skip=skip, encoder=self, full_encode=full_encode, **kwargs) if spec.varargs is not None and getattr(obj, spec.varargs, None) is not None: d.update({spec.varargs: getattr(obj, spec.varargs)}) if hasattr(obj, '_kwargs'): @@ -211,7 +206,7 @@ def runner(o): continue vv = redirect[k](obj) v_ = runner(vv) - d[k] = recursive_encoder( + d[k] = self._recursive_encoder( v_, skip=skip, encoder=self, @@ -240,9 +235,6 @@ def _convert_from_dict(d): if '@module' in d and '@class' in d: modname = d['@module'] classname = d['@class'] - # if classname in DictSerializer.REDIRECT.get(modname, {}): - # modname = DictSerializer.REDIRECT[modname][classname]["@module"] - # classname = DictSerializer.REDIRECT[modname][classname]["@class"] else: modname = None classname = None @@ -257,7 +249,7 @@ def _convert_from_dict(d): mod = __import__(modname, globals(), locals(), [classname], 0) if hasattr(mod, classname): cls_ = getattr(mod, classname) - data = {k: BaseEncoderDecoder._convert_from_dict(v) for k, v in d.items() if not k.startswith('@')} + data = {k: SerializerBase._convert_from_dict(v) for k, v in d.items() if not k.startswith('@')} return cls_(**data) elif np is not None and modname == 'numpy' and classname == 'array': if d['dtype'].startswith('complex'): @@ -265,38 +257,25 @@ def _convert_from_dict(d): return np.array(d['data'], dtype=d['dtype']) if issubclass(T_, (list, MutableSequence)): - return [BaseEncoderDecoder._convert_from_dict(x) for x in d] + return [SerializerBase._convert_from_dict(x) for x in d] return d - -if TYPE_CHECKING: - _ = TypeVar('EC', bound=BaseEncoderDecoder) - EC = Type[_] - - -def recursive_encoder(obj, skip: List[str] = [], encoder=None, full_encode=False, **kwargs): - """ - Walk through an object encoding it - """ - if encoder is None: - encoder = BaseEncoderDecoder() - T_ = type(obj) - if issubclass(T_, (list, tuple, 
MutableSequence)): - # Is it a core MutableSequence? + def _recursive_encoder(self, obj, skip: List[str] = [], encoder=None, full_encode=False, **kwargs): + """ + Walk through an object encoding it + """ + if encoder is None: + encoder = SerializerBase() + T_ = type(obj) + if issubclass(T_, (list, tuple, MutableSequence)): + # Is it a core MutableSequence? + if hasattr(obj, 'encode') and obj.__class__.__module__ != 'builtins': # strings have encode + return encoder._convert_to_dict(obj, skip, full_encode, **kwargs) + else: + return [self._recursive_encoder(it, skip, encoder, full_encode, **kwargs) for it in obj] + if isinstance(obj, dict): + return {kk: self._recursive_encoder(vv, skip, encoder, full_encode, **kwargs) for kk, vv in obj.items()} if hasattr(obj, 'encode') and obj.__class__.__module__ != 'builtins': # strings have encode return encoder._convert_to_dict(obj, skip, full_encode, **kwargs) - else: - return [recursive_encoder(it, skip, encoder, full_encode, **kwargs) for it in obj] - if isinstance(obj, dict): - return {kk: recursive_encoder(vv, skip, encoder, full_encode, **kwargs) for kk, vv in obj.items()} - if hasattr(obj, 'encode') and obj.__class__.__module__ != 'builtins': # strings have encode - return encoder._convert_to_dict(obj, skip, full_encode, **kwargs) - return obj + return obj - -def get_class_module(obj): - """ - Returns the REAL module of the class of the object. - """ - c = getattr(obj, '__old_class__', obj.__class__) - return c.__module__ diff --git a/src/easyscience/io/serializer_component.py b/src/easyscience/io/serializer_component.py new file mode 100644 index 00000000..15995412 --- /dev/null +++ b/src/easyscience/io/serializer_component.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project Any: + """ + Use an encoder to covert an EasyScience object into another format. 
Default is to a dictionary using `SerializerDict`. + + :param skip: List of field names as strings to skip when forming the encoded object + :param encoder: The encoder to be used for encoding the data. Default is `SerializerDict` + :param kwargs: Any additional key word arguments to be passed to the encoder + :return: encoded object containing all information to reform an EasyScience object. + """ + if encoder is None: + encoder = SerializerDict + encoder_obj = encoder() + return encoder_obj.encode(self, skip=skip, **kwargs) + + @classmethod + def decode(cls, obj: Any, decoder: Optional[SerializerBase] = None) -> Any: + """ + Re-create an EasyScience object from the output of an encoder. The default decoder is `SerializerDict`. + + :param obj: encoded EasyScience object + :param decoder: decoder to be used to reform the EasyScience object + :return: Reformed EasyScience object + """ + + if decoder is None: + decoder = SerializerDict + return decoder.decode(obj) + + def as_dict(self, skip: Optional[List[str]] = None) -> Dict[str, Any]: + """ + Convert an EasyScience object into a full dictionary using `SerializerDict`. + This is a shortcut for ```obj.encode(encoder=SerializerDict)``` + + :param skip: List of field names as strings to skip when forming the dictionary + :return: encoded object containing all information to reform an EasyScience object. + """ + + return self.encode(skip=skip, encoder=SerializerDict) + + @classmethod + def from_dict(cls, obj_dict: Dict[str, Any]) -> None: + """ + Re-create an EasyScience object from a full encoded dictionary. 
+ + :param obj_dict: dictionary containing the serialized contents (from `SerializerDict`) of an EasyScience object + :return: Reformed EasyScience object + """ + + return cls.decode(obj_dict, decoder=SerializerDict) diff --git a/src/easyscience/io/serializer_dict.py b/src/easyscience/io/serializer_dict.py new file mode 100644 index 00000000..95b28f09 --- /dev/null +++ b/src/easyscience/io/serializer_dict.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +__author__ = "https://github.com/materialsvirtuallab/monty/blob/master/monty/json.py" +__version__ = "3.0.0" +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project SerializerComponent: + """ + Re-create an EasyScience object from the dictionary representation. + + :param d: Dict representation of an EasyScience object. + :return: EasyScience object. + """ + + return SerializerBase._convert_from_dict(d) \ No newline at end of file diff --git a/src/easyscience/Objects/job/__init__.py b/src/easyscience/job/__init__.py similarity index 100% rename from src/easyscience/Objects/job/__init__.py rename to src/easyscience/job/__init__.py diff --git a/src/easyscience/Objects/job/analysis.py b/src/easyscience/job/analysis.py similarity index 72% rename from src/easyscience/Objects/job/analysis.py rename to src/easyscience/job/analysis.py index 1d99ece1..512ae556 100644 --- a/src/easyscience/Objects/job/analysis.py +++ b/src/easyscience/job/analysis.py @@ -1,18 +1,16 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project np.ndarray: raise NotImplementedError("calculate_theory not implemented") @abstractmethod def fit(self, - x: Union[xr.DataArray, np.ndarray], - y: Union[xr.DataArray, np.ndarray], - e: Union[xr.DataArray, np.ndarray], + x: np.ndarray, + y: np.ndarray, 
+ e: np.ndarray, **kwargs) -> None: raise NotImplementedError("fit not implemented") diff --git a/src/easyscience/Objects/job/experiment.py b/src/easyscience/job/experiment.py similarity index 68% rename from src/easyscience/Objects/job/experiment.py rename to src/easyscience/job/experiment.py index 1f2a63aa..807e0572 100644 --- a/src/easyscience/Objects/job/experiment.py +++ b/src/easyscience/job/experiment.py @@ -1,12 +1,12 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project BV: + def decode(cls, d: Dict) -> ComponentSerializer: """ :param d: Dict representation. :return: ComponentSerializer class. @@ -55,7 +55,7 @@ def decode(cls, d: Dict) -> BV: return BaseEncoderDecoder._convert_from_dict(d) @classmethod - def from_dict(cls, d: Dict[str, Any]) -> BV: + def from_dict(cls, d: Dict[str, Any]) -> ComponentSerializer: """ :param d: Dict representation. :return: ComponentSerializer class. 
@@ -70,7 +70,7 @@ class DataDictSerializer(DictSerializer): def encode( self, - obj: BV, + obj: ComponentSerializer, skip: Optional[List[str]] = None, full_encode: bool = False, **kwargs, @@ -95,7 +95,7 @@ def encode( return self._parse_dict(encoded) @classmethod - def decode(cls, d: Dict[str, Any]) -> BV: + def decode(cls, d: Dict[str, Any]) -> ComponentSerializer: """ This function is not implemented as a data dictionary does not contain the necessary information to re-form an EasyScience object. diff --git a/src/easyscience/Utils/io/json.py b/src/easyscience/legacy/json.py similarity index 57% rename from src/easyscience/Utils/io/json.py rename to src/easyscience/legacy/json.py index 6307c69b..6464e52a 100644 --- a/src/easyscience/Utils/io/json.py +++ b/src/easyscience/legacy/json.py @@ -3,27 +3,23 @@ __author__ = 'https://github.com/materialsvirtuallab/monty/blob/master/monty/json.py' __version__ = '3.0.0' -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project str: + def encode(self, obj: ComponentSerializer, skip: List[str] = []) -> str: """ Returns a json string representation of the ComponentSerializer object. """ @@ -35,16 +31,16 @@ def encode(self, obj: BV, skip: List[str] = []) -> str: return json.dumps(obj, cls=ENCODER) @classmethod - def decode(cls, data: str) -> BV: + def decode(cls, data: str) -> ComponentSerializer: return json.loads(data, cls=JsonDecoderTemplate) class JsonDataSerializer(BaseEncoderDecoder): - def encode(self, obj: BV, skip: List[str] = []) -> str: + def encode(self, obj: ComponentSerializer, skip: List[str] = []) -> str: """ Returns a json string representation of the ComponentSerializer object. 
""" - from easyscience.Utils.io.dict import DataDictSerializer + from .dict import DataDictSerializer ENCODER = type( JsonEncoderTemplate.__name__, @@ -60,7 +56,7 @@ def encode(self, obj: BV, skip: List[str] = []) -> str: return json.dumps(obj, cls=ENCODER) @classmethod - def decode(cls, data: str) -> BV: + def decode(cls, data: str) -> ComponentSerializer: raise NotImplementedError('It is not possible to reconstitute objects from data only objects.') @@ -121,51 +117,3 @@ def decode(self, s): """ d = json.JSONDecoder.decode(self, s) return self.__class__._converter(d) - - -def jsanitize(obj, strict=False, allow_bson=False): - """ - This method cleans an input json-like object, either a list or a dict or - some sequence, nested or otherwise, by converting all non-string - dictionary keys (such as int and float) to strings, and also recursively - encodes all objects using Monty's as_dict() protocol. - - Args: - obj: input json-like object. - strict (bool): This parameters sets the behavior when jsanitize - encounters an object it does not understand. If strict is True, - jsanitize will try to get the as_dict() attribute of the object. If - no such attribute is found, an attribute error will be thrown. If - strict is False, jsanitize will simply call str(object) to convert - the object to a string representation. - allow_bson (bool): This parameters sets the behavior when jsanitize - encounters an bson supported type such as objectid and datetime. If - True, such bson types will be ignored, allowing for proper - insertion into MongoDb databases. - - Returns: - Sanitized dict that can be json serialized. 
- """ - # if allow_bson and ( - # isinstance(obj, (datetime.datetime, bytes)) - # or (bson is not None and isinstance(obj, bson.objectid.ObjectId)) - # ): - # return obj - if isinstance(obj, (list, tuple)): - return [jsanitize(i, strict=strict, allow_bson=allow_bson) for i in obj] - if np is not None and isinstance(obj, np.ndarray): - return [jsanitize(i, strict=strict, allow_bson=allow_bson) for i in obj.tolist()] - if isinstance(obj, dict): - return {k.__str__(): jsanitize(v, strict=strict, allow_bson=allow_bson) for k, v in obj.items()} - if isinstance(obj, (int, float)): - return obj - if obj is None: - return None - - if not strict: - return obj.__str__() - - if isinstance(obj, str): - return obj.__str__() - - return jsanitize(obj.as_dict(), strict=strict, allow_bson=allow_bson) diff --git a/src/easyscience/Objects/core.py b/src/easyscience/legacy/legacy_core.py similarity index 90% rename from src/easyscience/Objects/core.py rename to src/easyscience/legacy/legacy_core.py index 0754040f..daddabee 100644 --- a/src/easyscience/Objects/core.py +++ b/src/easyscience/legacy/legacy_core.py @@ -1,12 +1,9 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project Any: + def encode(self, skip: Optional[List[str]] = None, encoder: Optional[BaseEncoderDecoder] = None, **kwargs) -> Any: """ Use an encoder to covert an EasyScience object into another format. Default is to a dictionary using `DictSerializer`. @@ -50,7 +47,7 @@ def encode(self, skip: Optional[List[str]] = None, encoder: Optional[EC] = None, return encoder_obj.encode(self, skip=skip, **kwargs) @classmethod - def decode(cls, obj: Any, decoder: Optional[EC] = None) -> Any: + def decode(cls, obj: Any, decoder: Optional[BaseEncoderDecoder] = None) -> Any: """ Re-create an EasyScience object from the output of an encoder. The default decoder is `DictSerializer`. 
@@ -85,7 +82,7 @@ def from_dict(cls, obj_dict: Dict[str, Any]) -> None: return cls.decode(obj_dict, decoder=DictSerializer) - def encode_data(self, skip: Optional[List[str]] = None, encoder: Optional[EC] = None, **kwargs) -> Any: + def encode_data(self, skip: Optional[List[str]] = None, encoder: Optional[BaseEncoderDecoder] = None, **kwargs) -> Any: """ Returns just the data in an EasyScience object win the format specified by an encoder. diff --git a/src/easyscience/Utils/io/xml.py b/src/easyscience/legacy/xml.py similarity index 92% rename from src/easyscience/Utils/io/xml.py rename to src/easyscience/legacy/xml.py index 7179b259..e0898409 100644 --- a/src/easyscience/Utils/io/xml.py +++ b/src/easyscience/legacy/xml.py @@ -1,12 +1,9 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project 2) & (sys.version_info.minor > 8) @@ -35,7 +32,7 @@ class XMLSerializer(BaseEncoderDecoder): def encode( self, - obj: BV, + obj: ComponentSerializer, skip: Optional[List[str]] = None, data_only: bool = False, fast: bool = False, @@ -76,7 +73,7 @@ def encode( return header + ET.tostring(block, encoding='unicode') @classmethod - def decode(cls, data: str) -> BV: + def decode(cls, data: str) -> ComponentSerializer: """ Decode an EasyScience object which has been encoded in XML format. 
diff --git a/src/easyscience/models/__init__.py b/src/easyscience/models/__init__.py index 47316878..de175a27 100644 --- a/src/easyscience/models/__init__.py +++ b/src/easyscience/models/__init__.py @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project np.ndarray: return np.polyval([c.value for c in self.coefficients], x) @@ -78,25 +64,3 @@ def __repr__(self): s = ' + '.join(s) return 'Polynomial({}, {})'.format(self.name, s) - -class Line(BaseObj): - m: ClassVar[Parameter] - c: ClassVar[Parameter] - - def __init__( - self, - m: Optional[Union[Parameter, float]] = None, - c: Optional[Union[Parameter, float]] = None, - ): - super(Line, self).__init__('line', m=Parameter('m', 1.0), c=Parameter('c', 0.0)) - if m is not None: - self.m = m - if c is not None: - self.c = c - - # @designate_calc_fn can be used to inject parameters into the calculation function. i.e. 
_m = m.value - def __call__(self, x: np.ndarray, *args, **kwargs) -> np.ndarray: - return self.m.value * x + self.c.value - - def __repr__(self): - return '{}({}, {})'.format(self.__class__.__name__, self.m, self.c) diff --git a/src/easyscience/utils/__init__.py b/src/easyscience/utils/__init__.py new file mode 100644 index 00000000..de175a27 --- /dev/null +++ b/src/easyscience/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project None: +def addLoggedProp(inst: SerializerComponent, name: str, *args, **kwargs) -> None: cls = type(inst) annotations = getattr(cls, '__annotations__', False) if not hasattr(cls, '__perinstance'): @@ -32,7 +29,7 @@ def addLoggedProp(inst: BV, name: str, *args, **kwargs) -> None: setattr(cls, name, LoggedProperty(*args, **kwargs)) -def addProp(inst: BV, name: str, *args, **kwargs) -> None: +def addProp(inst: SerializerComponent, name: str, *args, **kwargs) -> None: cls = type(inst) annotations = getattr(cls, '__annotations__', False) if not hasattr(cls, '__perinstance'): @@ -46,7 +43,7 @@ def addProp(inst: BV, name: str, *args, **kwargs) -> None: setattr(cls, name, property(*args, **kwargs)) -def removeProp(inst: BV, name: str) -> None: +def removeProp(inst: SerializerComponent, name: str) -> None: cls = type(inst) if not hasattr(cls, '__perinstance'): cls = type(cls.__name__, (cls,), {'__module__': __name__}) @@ -56,7 +53,7 @@ def removeProp(inst: BV, name: str) -> None: delattr(cls, name) -def generatePath(model_obj: B, skip_first: bool = False) -> Tuple[List[int], List[str]]: +def generatePath(model_obj: BasedBase, skip_first: bool = False) -> Tuple[List[int], List[str]]: pars = model_obj.get_parameters() start_idx = 0 + int(skip_first) unique_names = [] 
diff --git a/src/easyscience/Utils/classUtils.py b/src/easyscience/utils/classUtils.py similarity index 94% rename from src/easyscience/Utils/classUtils.py rename to src/easyscience/utils/classUtils.py index 0b405fa5..e0c01e97 100644 --- a/src/easyscience/Utils/classUtils.py +++ b/src/easyscience/utils/classUtils.py @@ -1,9 +1,6 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project numbers.Number: return self._value @value.setter - @property_stack_deco + @property_stack def value(self, value: Union[list, np.ndarray]) -> None: """ Set the value of self. diff --git a/src/easyscience/Objects/variable/descriptor_array.py b/src/easyscience/variable/descriptor_array.py similarity index 99% rename from src/easyscience/Objects/variable/descriptor_array.py rename to src/easyscience/variable/descriptor_array.py index c9b154e5..c7a1d8ca 100644 --- a/src/easyscience/Objects/variable/descriptor_array.py +++ b/src/easyscience/variable/descriptor_array.py @@ -16,7 +16,7 @@ from scipp import Variable from easyscience.global_object.undo_redo import PropertyStack -from easyscience.global_object.undo_redo import property_stack_deco +from easyscience.global_object.undo_redo import property_stack from .descriptor_base import DescriptorBase from .descriptor_number import DescriptorNumber @@ -150,7 +150,7 @@ def value(self) -> numbers.Number: return self._array.values @value.setter - @property_stack_deco + @property_stack def value(self, value: Union[list, np.ndarray]) -> None: """ Set the value of self. 
Ensures the input is an array and matches the shape of the existing array. @@ -225,7 +225,7 @@ def variance(self) -> np.ndarray: return self._array.variances @variance.setter - @property_stack_deco + @property_stack def variance(self, variance: Union[list, np.ndarray]) -> None: """ Set the variance of self. Ensures the input is an array and matches the shape of the existing values. @@ -259,7 +259,7 @@ def error(self) -> Optional[np.ndarray]: return np.sqrt(self._array.variances) @error.setter - @property_stack_deco + @property_stack def error(self, error: Union[list, np.ndarray]) -> None: """ Set the standard deviation for the parameter, which updates the variances. diff --git a/src/easyscience/Objects/variable/descriptor_base.py b/src/easyscience/variable/descriptor_base.py similarity index 94% rename from src/easyscience/Objects/variable/descriptor_base.py rename to src/easyscience/variable/descriptor_base.py index b525d4f1..ab30a9f7 100644 --- a/src/easyscience/Objects/variable/descriptor_base.py +++ b/src/easyscience/variable/descriptor_base.py @@ -1,6 +1,6 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project str: return self._name @name.setter - @property_stack_deco + @property_stack def name(self, new_name: str) -> None: """ Set the name. @@ -118,7 +118,7 @@ def display_name(self) -> str: return display_name @display_name.setter - @property_stack_deco + @property_stack def display_name(self, name: str) -> None: """ Set the pretty display name. 
diff --git a/src/easyscience/Objects/variable/descriptor_bool.py b/src/easyscience/variable/descriptor_bool.py similarity index 95% rename from src/easyscience/Objects/variable/descriptor_bool.py rename to src/easyscience/variable/descriptor_bool.py index 768b35b1..23869172 100644 --- a/src/easyscience/Objects/variable/descriptor_bool.py +++ b/src/easyscience/variable/descriptor_bool.py @@ -3,7 +3,7 @@ from typing import Any from typing import Optional -from easyscience.global_object.undo_redo import property_stack_deco +from easyscience.global_object.undo_redo import property_stack from .descriptor_base import DescriptorBase @@ -46,7 +46,7 @@ def value(self) -> bool: return self._bool_value @value.setter - @property_stack_deco + @property_stack def value(self, value: bool) -> None: """ Set the value of self. diff --git a/src/easyscience/Objects/variable/descriptor_number.py b/src/easyscience/variable/descriptor_number.py similarity index 83% rename from src/easyscience/Objects/variable/descriptor_number.py rename to src/easyscience/variable/descriptor_number.py index cfba4a44..1b1e257b 100644 --- a/src/easyscience/Objects/variable/descriptor_number.py +++ b/src/easyscience/variable/descriptor_number.py @@ -13,11 +13,27 @@ from scipp import Variable from easyscience.global_object.undo_redo import PropertyStack -from easyscience.global_object.undo_redo import property_stack_deco +from easyscience.global_object.undo_redo import property_stack from .descriptor_base import DescriptorBase +# Why is this a decorator? Because otherwise we would need a flag on the convert_unit method to avoid +# infinite recursion. This is a bit cleaner as it avoids the need for a internal only flag on a user method. +def notify_observers(func): + """ + Decorator to notify observers of a change in the descriptor. 
+ + :param func: Function to be decorated + :return: Decorated function + """ + def wrapper(self, *args, **kwargs): + result = func(self, *args, **kwargs) + self._notify_observers() + return result + + return wrapper + class DescriptorNumber(DescriptorBase): """ A `Descriptor` for Number values with units. The internal representation is a scipp scalar. @@ -47,6 +63,8 @@ def __init__( param parent: Parent of the descriptor .. note:: Undo/Redo functionality is implemented for the attributes `variance`, `error`, `unit` and `value`. """ + self._observers: List[DescriptorNumber] = [] + if not isinstance(value, numbers.Number) or isinstance(value, bool): raise TypeError(f'{value=} must be a number') if variance is not None: @@ -72,7 +90,8 @@ def __init__( # Call convert_unit during initialization to ensure that the unit has no numbers in it, and to ensure unit consistency. if self.unit is not None: - self.convert_unit(self._base_unit()) + self._convert_unit(self._base_unit()) + @classmethod def from_scipp(cls, name: str, full_value: Variable, **kwargs) -> DescriptorNumber: @@ -90,6 +109,33 @@ def from_scipp(cls, name: str, full_value: Variable, **kwargs) -> DescriptorNumb raise TypeError(f'{full_value=} must be a scipp scalar') return cls(name=name, value=full_value.value, unit=full_value.unit, variance=full_value.variance, **kwargs) + def _attach_observer(self, observer: DescriptorNumber) -> None: + """Attach an observer to the descriptor.""" + self._observers.append(observer) + + def _detach_observer(self, observer: DescriptorNumber) -> None: + """Detach an observer from the descriptor.""" + self._observers.remove(observer) + + def _notify_observers(self) -> None: + """Notify all observers of a change.""" + for observer in self._observers: + observer._update() + + def _validate_dependencies(self, origin=None) -> None: + """Ping all observers to check if any cyclic dependencies have been introduced. + + :param origin: Unique_name of the origin of this validation check. 
Used to avoid cyclic depenencies. + """ + if origin == self.unique_name: + raise RuntimeError('\n Cyclic dependency detected!\n' + + f'An update of {self.unique_name} leads to it updating itself.\n' + + 'Please check your dependencies.') + if origin is None: + origin = self.unique_name + for observer in self._observers: + observer._validate_dependencies(origin=origin) + @property def full_value(self) -> Variable: """ @@ -115,7 +161,8 @@ def value(self) -> numbers.Number: return self._scalar.value @value.setter - @property_stack_deco + @notify_observers + @property_stack def value(self, value: numbers.Number) -> None: """ Set the value of self. This should be usable for most cases. The full value can be obtained from `obj.full_value`. @@ -154,7 +201,8 @@ def variance(self) -> float: return self._scalar.variance @variance.setter - @property_stack_deco + @notify_observers + @property_stack def variance(self, variance_float: float) -> None: """ Set the variance. @@ -181,7 +229,8 @@ def error(self) -> float: return float(np.sqrt(self._scalar.variance)) @error.setter - @property_stack_deco + @notify_observers + @property_stack def error(self, value: float) -> None: """ Set the standard deviation for the parameter. @@ -198,7 +247,9 @@ def error(self, value: float) -> None: else: self._scalar.variance = None - def convert_unit(self, unit_str: str) -> None: + # When we convert units internally, we dont want to notify observers as this can cause infinite recursion. + # Therefore the convert_unit method is split into two methods, a private internal method and a public method. + def _convert_unit(self, unit_str: str) -> None: """ Convert the value from one unit system to another. @@ -229,6 +280,15 @@ def set_scalar(obj, scalar): # Update the scalar self._scalar = new_scalar + # When the user calls convert_unit, we want to notify observers of the change to propagate the change. 
+ @notify_observers + def convert_unit(self, unit_str: str) -> None: + """ + Convert the value from one unit system to another. + + :param unit_str: New unit in string form + """ + self._convert_unit(unit_str) # Just to get return type right def __copy__(self) -> DescriptorNumber: @@ -267,11 +327,11 @@ def __add__(self, other: Union[DescriptorNumber, numbers.Number]) -> DescriptorN elif type(other) is DescriptorNumber: original_unit = other.unit try: - other.convert_unit(self.unit) + other._convert_unit(self.unit) except UnitError: raise UnitError(f'Values with units {self.unit} and {other.unit} cannot be added') from None new_value = self.full_value + other.full_value - other.convert_unit(original_unit) + other._convert_unit(original_unit) else: return NotImplemented descriptor_number = DescriptorNumber.from_scipp(name=self.name, full_value=new_value) @@ -297,11 +357,11 @@ def __sub__(self, other: Union[DescriptorNumber, numbers.Number]) -> DescriptorN elif type(other) is DescriptorNumber: original_unit = other.unit try: - other.convert_unit(self.unit) + other._convert_unit(self.unit) except UnitError: raise UnitError(f'Values with units {self.unit} and {other.unit} cannot be subtracted') from None new_value = self.full_value - other.full_value - other.convert_unit(original_unit) + other._convert_unit(original_unit) else: return NotImplemented descriptor_number = DescriptorNumber.from_scipp(name=self.name, full_value=new_value) @@ -327,7 +387,7 @@ def __mul__(self, other: Union[DescriptorNumber, numbers.Number]) -> DescriptorN else: return NotImplemented descriptor_number = DescriptorNumber.from_scipp(name=self.name, full_value=new_value) - descriptor_number.convert_unit(descriptor_number._base_unit()) + descriptor_number._convert_unit(descriptor_number._base_unit()) descriptor_number.name = descriptor_number.unique_name return descriptor_number @@ -355,7 +415,7 @@ def __truediv__(self, other: Union[DescriptorNumber, numbers.Number]) -> Descrip else: return 
NotImplemented descriptor_number = DescriptorNumber.from_scipp(name=self.name, full_value=new_value) - descriptor_number.convert_unit(descriptor_number._base_unit()) + descriptor_number._convert_unit(descriptor_number._base_unit()) descriptor_number.name = descriptor_number.unique_name return descriptor_number @@ -415,6 +475,9 @@ def __abs__(self) -> DescriptorNumber: return descriptor_number def _base_unit(self) -> str: + """ + Extract the base unit from the unit string by removing numeric components and scientific notation. + """ string = str(self._scalar.unit) for i, letter in enumerate(string): if letter == 'e': @@ -422,4 +485,4 @@ def _base_unit(self) -> str: return string[i:] elif letter not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '+', '-']: return string[i:] - return '' + return '' \ No newline at end of file diff --git a/src/easyscience/Objects/variable/descriptor_str.py b/src/easyscience/variable/descriptor_str.py similarity index 94% rename from src/easyscience/Objects/variable/descriptor_str.py rename to src/easyscience/variable/descriptor_str.py index 1abe4e4e..17110166 100644 --- a/src/easyscience/Objects/variable/descriptor_str.py +++ b/src/easyscience/variable/descriptor_str.py @@ -3,7 +3,7 @@ from typing import Any from typing import Optional -from easyscience.global_object.undo_redo import property_stack_deco +from easyscience.global_object.undo_redo import property_stack from .descriptor_base import DescriptorBase @@ -45,7 +45,7 @@ def value(self) -> str: return self._string @value.setter - @property_stack_deco + @property_stack def value(self, value: str) -> None: """ Set the value of self. 
diff --git a/src/easyscience/Objects/variable/parameter.py b/src/easyscience/variable/parameter.py similarity index 58% rename from src/easyscience/Objects/variable/parameter.py rename to src/easyscience/variable/parameter.py index 03b3c230..c459946c 100644 --- a/src/easyscience/Objects/variable/parameter.py +++ b/src/easyscience/variable/parameter.py @@ -1,34 +1,30 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project max: raise ValueError(f'{value=} can not be greater than {max=}') - if np.isclose(min, max, rtol=1e-9, atol=0.0): raise ValueError('The min and max bounds cannot be identical. Please use fixed=True instead to fix the value.') if not isinstance(fixed, bool): raise TypeError('`fixed` must be either True or False') - + self._independent = True self._fixed = fixed # For fitting, but must be initialized before super().__init__ self._min = sc.scalar(float(min), unit=unit) self._max = sc.scalar(float(max), unit=unit) @@ -115,14 +108,192 @@ def __init__( weakref.finalize(self, self._callback.fdel) # Create additional fitting elements - self._enabled = enabled self._initial_scalar = copy.deepcopy(self._scalar) - builtin_constraint = { - # Last argument in constructor is the name of the property holding the value of the constraint - 'min': SelfConstraint(self, '>=', 'min'), - 'max': SelfConstraint(self, '<=', 'max'), - } - self._constraints = Constraints(builtin=builtin_constraint, user={}, virtual={}) + + @classmethod + def from_dependency(cls, name: str, dependency_expression: str, dependency_map: Optional[dict] = None, **kwargs) -> Parameter: # noqa: E501 + """ + Create a dependent Parameter directly from a dependency expression. + + :param name: The name of the parameter + :param dependency_expression: The dependency expression to evaluate. 
This should be a string which can be evaluated by the ASTEval interpreter. + :param dependency_map: A dictionary of dependency expression symbol name and dependency object pairs. This is inserted into the asteval interpreter to resolve dependencies. + :param kwargs: Additional keyword arguments to pass to the Parameter constructor. + :return: A new dependent Parameter object. + """ # noqa: E501 + parameter = cls(name=name, value=0.0, unit='', variance=0.0, min=-np.inf, max=np.inf, **kwargs) + parameter.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) + return parameter + + def _update(self) -> None: + """ + Update the parameter. This is called by the DescriptorNumbers/Parameters who have this Parameter as a dependency. + """ + if not self._independent: + # Update the value of the parameter using the dependency interpreter + temporary_parameter = self._dependency_interpreter(self._clean_dependency_string) + self._scalar.value = temporary_parameter.value + self._scalar.unit = temporary_parameter.unit + self._scalar.variance = temporary_parameter.variance + self._min.value = temporary_parameter.min if isinstance(temporary_parameter, Parameter) else temporary_parameter.value # noqa: E501 + self._max.value = temporary_parameter.max if isinstance(temporary_parameter, Parameter) else temporary_parameter.value # noqa: E501 + self._min.unit = temporary_parameter.unit + self._max.unit = temporary_parameter.unit + self._notify_observers() + else: + warnings.warn('This parameter is not dependent. It cannot be updated.') + + def make_dependent_on(self, dependency_expression: str, dependency_map: Optional[dict] = None) -> None: + """ + Make this parameter dependent on another parameter. This will overwrite the current value, unit, variance, min and max. 
+ + How to use the dependency map: + If a parameter c has a dependency expression of 'a + b', where a and b are parameters belonging to the model class, + then the dependency map needs to have the form {'a': model.a, 'b': model.b}, where model is the model class. + I.e. the values are the actual objects, whereas the keys are how they are represented in the dependency expression. + + The dependency map is not needed if the dependency expression uses the unique names of the parameters. + Unique names in dependency expressions are defined by quotes, e.g. 'Parameter_0' or "Parameter_0" depending on the quotes used for the expression. + + :param dependency_expression: The dependency expression to evaluate. This should be a string which can be evaluated by a python interpreter. + :param dependency_map: A dictionary of dependency expression symbol name and dependency object pairs. This is inserted into the asteval interpreter to resolve dependencies. + """ # noqa: E501 + if not isinstance(dependency_expression, str): + raise TypeError('`dependency_expression` must be a string representing a valid dependency expression.') + if not (isinstance(dependency_map, dict) or dependency_map is None): + raise TypeError('`dependency_map` must be a dictionary of dependencies and their corresponding names in the dependecy expression.') # noqa: E501 + if isinstance(dependency_map, dict): + for key, value in dependency_map.items(): + if not isinstance(key, str): + raise TypeError('`dependency_map` keys must be strings representing the names of the dependencies in the dependency expression.') # noqa: E501 + if not isinstance(value, DescriptorNumber): + raise TypeError(f'`dependency_map` values must be DescriptorNumbers or Parameters. 
Got {type(value)} for {key}.') # noqa: E501 + + # If we're overwriting the dependency, store the old attributes + # in case we need to revert back to the old dependency + self._previous_independent = self._independent + if not self._independent: + self._previous_dependency = { + '_dependency_string': self._dependency_string, + '_dependency_map': self._dependency_map, + '_dependency_interpreter': self._dependency_interpreter, + '_clean_dependency_string': self._clean_dependency_string, + } + for dependency in self._dependency_map.values(): + dependency._detach_observer(self) + + self._independent = False + self._dependency_string = dependency_expression + self._dependency_map = dependency_map if dependency_map is not None else {} + # List of allowed python constructs for the asteval interpreter + asteval_config = {'import': False, 'importfrom': False, 'assert': False, + 'augassign': False, 'delete': False, 'if': True, + 'ifexp': True, 'for': False, 'formattedvalue': False, + 'functiondef': False, 'print': False, 'raise': False, + 'listcomp': False, 'dictcomp': False, 'setcomp': False, + 'try': False, 'while': False, 'with': False} + self._dependency_interpreter = Interpreter(config=asteval_config) + + # Process the dependency expression for unique names + try: + self._process_dependency_unique_names(self._dependency_string) + except ValueError as error: + self._revert_dependency(skip_detach=True) + raise error + + for key, value in self._dependency_map.items(): + self._dependency_interpreter.symtable[key] = value + self._dependency_interpreter.readonly_symbols.add(key) # Dont allow overwriting of the dependencies in the dependency expression # noqa: E501 + value._attach_observer(self) + # Check the dependency expression for errors + try: + dependency_result = self._dependency_interpreter.eval(self._clean_dependency_string, raise_errors=True) + except NameError as message: + self._revert_dependency() + raise NameError('\nUnknown name encountered in dependecy 
expression:'+ + '\n'+'\n'.join(str(message).split("\n")[1:])+ + '\nPlease check your expression or add the name to the `dependency_map`') from None + except Exception as message: + self._revert_dependency() + raise SyntaxError('\nError encountered in dependecy expression:'+ + '\n'+'\n'.join(str(message).split("\n")[1:])+ + '\nPlease check your expression') from None + if not isinstance(dependency_result, DescriptorNumber): + error_string = self._dependency_string + self._revert_dependency() + raise TypeError(f'The dependency expression: "{error_string}" returned a {type(dependency_result)}, it should return a Parameter or DescriptorNumber.') # noqa: E501 + # Check for cyclic dependencies + try: + self._validate_dependencies() + except RuntimeError as error: + self._revert_dependency() + raise error + # Update the parameter with the dependency result + self._fixed = False + self._update() + + def make_independent(self) -> None: + """ + Make this parameter independent. + This will remove the dependency expression, the dependency map and the dependency interpreter. + + :return: None + """ + if not self._independent: + for dependency in self._dependency_map.values(): + dependency._detach_observer(self) + self._independent = True + del self._dependency_map + del self._dependency_interpreter + del self._dependency_string + del self._clean_dependency_string + else: + raise AttributeError('This parameter is already independent.') + + @property + def independent(self) -> bool: + """ + Is the parameter independent? + + :return: True = independent, False = dependent + """ + return self._independent + + @independent.setter + def independent(self, value: bool) -> None: + raise AttributeError('This property is read-only. Use `make_independent` and `make_dependent_on` to change the state of the parameter.') # noqa: E501 + + @property + def dependency_expression(self) -> str: + """ + Get the dependency expression of this parameter. 
+ + :return: The dependency expression of this parameter. + """ + if not self._independent: + return self._dependency_string + else: + raise AttributeError('This parameter is independent. It has no dependency expression.') + + @dependency_expression.setter + def dependency_expression(self, new_expression: str) -> None: + raise AttributeError('Dependency expression is read-only. Use `make_dependent_on` to change the dependency expression.') # noqa: E501 + + @property + def dependency_map(self) -> Dict[str, DescriptorNumber]: + """ + Get the dependency map of this parameter. + + :return: The dependency map of this parameter. + """ + if not self._independent: + return self._dependency_map + else: + raise AttributeError('This parameter is independent. It has no dependency map.') + + @dependency_map.setter + def dependency_map(self, new_map: Dict[str, DescriptorNumber]) -> None: + raise AttributeError('Dependency map is read-only. Use `make_dependent_on` to change the dependency map.') @property def value_no_call_back(self) -> numbers.Number: @@ -167,57 +338,79 @@ def value(self) -> numbers.Number: return self._scalar.value @value.setter - @property_stack_deco + @property_stack def value(self, value: numbers.Number) -> None: """ Set the value of self. This only updates the value of the scipp scalar. 
:param value: New value of self """ - if not self.enabled: - if global_object.debug: - raise CoreSetException(f'{str(self)} is not enabled.') - return + if self._independent: + if not isinstance(value, numbers.Number): + raise TypeError(f'{value=} must be a number') + + value = float(value) + if value < self._min.value: + value = self._min.value + if value > self._max.value: + value = self._max.value - if not isinstance(value, numbers.Number) or isinstance(value, bool): - raise TypeError(f'{value=} must be a number') + self._scalar.value = value - # Need to set the value for constraints to be functional - self._scalar.value = float(value) - # if self._callback.fset is not None: - # self._callback.fset(self._scalar.value) + if self._callback.fset is not None: + self._callback.fset(self._scalar.value) - # Deals with min/max - value = self._constraint_runner(self.builtin_constraints, self._scalar.value) + # Notify observers of the change + self._notify_observers() + else: + raise AttributeError("This is a dependent parameter, its value cannot be set directly.") - # Deals with user constraints - # Changes should not be registrered in the undo/redo stack - stack_state = global_object.stack.enabled - if stack_state: - global_object.stack.force_state(False) - try: - value = self._constraint_runner(self.user_constraints, value) - finally: - global_object.stack.force_state(stack_state) + @DescriptorNumber.variance.setter + def variance(self, variance_float: float) -> None: + """ + Set the variance. 
- value = self._constraint_runner(self._constraints.virtual, value) + :param variance_float: Variance as a float + """ + if self._independent: + DescriptorNumber.variance.fset(self, variance_float) + else: + raise AttributeError("This is a dependent parameter, its variance cannot be set directly.") - self._scalar.value = float(value) - if self._callback.fset is not None: - self._callback.fset(self._scalar.value) + @DescriptorNumber.error.setter + def error(self, value: float) -> None: + """ + Set the standard deviation for the parameter. - def convert_unit(self, unit_str: str) -> None: + :param value: New error value + """ + if self._independent: + DescriptorNumber.error.fset(self, value) + else: + raise AttributeError("This is a dependent parameter, its error cannot be set directly.") + + def _convert_unit(self, unit_str: str) -> None: """ Perform unit conversion. The value, max and min can change on unit change. :param new_unit: new unit :return: None """ - super().convert_unit(unit_str) + super()._convert_unit(unit_str=unit_str) new_unit = sc.Unit(unit_str) # unit_str is tested in super method self._min = self._min.to(unit=new_unit) self._max = self._max.to(unit=new_unit) + @notify_observers + def convert_unit(self, unit_str: str) -> None: + """ + Perform unit conversion. The value, max and min can change on unit change. + + :param new_unit: new unit + :return: None + """ + self._convert_unit(unit_str) + @property def min(self) -> numbers.Number: """ @@ -228,7 +421,7 @@ def min(self) -> numbers.Number: return self._min.value @min.setter - @property_stack_deco + @property_stack def min(self, min_value: numbers.Number) -> None: """ Set the minimum value for fitting. 
@@ -237,14 +430,18 @@ def min(self, min_value: numbers.Number) -> None: :param min_value: new minimum value :return: None """ - if not isinstance(min_value, numbers.Number): - raise TypeError('`min` must be a number') - if np.isclose(min_value, self._max.value, rtol=1e-9, atol=0.0): - raise ValueError('The min and max bounds cannot be identical. Please use fixed=True instead to fix the value.') - if min_value <= self.value: - self._min.value = min_value + if self._independent: + if not isinstance(min_value, numbers.Number): + raise TypeError('`min` must be a number') + if np.isclose(min_value, self._max.value, rtol=1e-9, atol=0.0): + raise ValueError('The min and max bounds cannot be identical. Please use fixed=True instead to fix the value.') + if min_value <= self.value: + self._min.value = min_value + else: + raise ValueError(f'The current value ({self.value}) is smaller than the desired min value ({min_value}).') + self._notify_observers() else: - raise ValueError(f'The current value ({self.value}) is smaller than the desired min value ({min_value}).') + raise AttributeError("This is a dependent parameter, its minimum value cannot be set directly.") @property def max(self) -> numbers.Number: @@ -256,7 +453,7 @@ def max(self) -> numbers.Number: return self._max.value @max.setter - @property_stack_deco + @property_stack def max(self, max_value: numbers.Number) -> None: """ Get the maximum value for fitting. @@ -265,14 +462,18 @@ def max(self, max_value: numbers.Number) -> None: :param max_value: new maximum value :return: None """ - if not isinstance(max_value, numbers.Number): - raise TypeError('`max` must be a number') - if np.isclose(max_value, self._min.value, rtol=1e-9, atol=0.0): - raise ValueError('The min and max bounds cannot be identical. 
Please use fixed=True instead to fix the value.') - if max_value >= self.value: - self._max.value = max_value + if self._independent: + if not isinstance(max_value, numbers.Number): + raise TypeError('`max` must be a number') + if np.isclose(max_value, self._min.value, rtol=1e-9, atol=0.0): + raise ValueError('The min and max bounds cannot be identical. Please use fixed=True instead to fix the value.') + if max_value >= self.value: + self._max.value = max_value + else: + raise ValueError(f'The current value ({self.value}) is greater than the desired max value ({max_value}).') + self._notify_observers() else: - raise ValueError(f'The current value ({self.value}) is greater than the desired max value ({max_value}).') + raise AttributeError("This is a dependent parameter, its maximum value cannot be set directly.") @property def fixed(self) -> bool: @@ -284,7 +485,7 @@ def fixed(self) -> bool: return self._fixed @fixed.setter - @property_stack_deco + @property_stack def fixed(self, fixed: bool) -> None: """ Change the parameter vary while fitting state. @@ -292,17 +493,17 @@ def fixed(self, fixed: bool) -> None: :param fixed: True = fixed, False = can vary """ - if not self.enabled: - if global_object.stack.enabled: - # Remove the recorded change from the stack - global_object.stack.pop() - if global_object.debug: - raise CoreSetException(f'{str(self)} is not enabled.') - return if not isinstance(fixed, bool): raise ValueError(f'{fixed=} must be a boolean. Got {type(fixed)}') - self._fixed = fixed + if self._independent: + self._fixed = fixed + else: + if self._global_object.stack.enabled: + # Remove the recorded change from the stack + global_object.stack.pop() + raise AttributeError("This is a dependent parameter, dependent parameters cannot be fixed.") + # Is this alias really needed? 
@property def free(self) -> bool: return not self.fixed @@ -311,112 +512,47 @@ def free(self) -> bool: def free(self, value: bool) -> None: self.fixed = not value - @property - def bounds(self) -> Tuple[numbers.Number, numbers.Number]: - """ - Get the bounds of the parameter. - - :return: Tuple of the parameters minimum and maximum values - """ - return self.min, self.max - @bounds.setter - def bounds(self, new_bound: Tuple[numbers.Number, numbers.Number]) -> None: - """ - Set the bounds of the parameter. *This will also enable the parameter*. - - :param new_bound: New bounds. This should be a tuple of (min, max). - """ - old_min = self.min - old_max = self.max - new_min, new_max = new_bound - - # Begin macro operation for undo/redo - close_macro = False - if self._global_object.stack.enabled: - self._global_object.stack.beginMacro('Setting bounds') - close_macro = True - - try: - # Update bounds - self.min = new_min - self.max = new_max - except ValueError: - # Rollback on failure - self.min = old_min - self.max = old_max - if close_macro: - self._global_object.stack.endMacro() - raise ValueError(f'Current parameter value: {self._scalar.value} must be within {new_bound=}') - - # Enable the parameter if needed - if not self.enabled: - self.enabled = True - # Free parameter if needed - if self.fixed: - self.fixed = False - - # End macro operation - if close_macro: - self._global_object.stack.endMacro() - - @property - def builtin_constraints(self) -> Dict[str, SelfConstraint]: - """ - Get the built in constrains of the object. Typically these are the min/max - - :return: Dictionary of constraints which are built into the system + def _revert_dependency(self, skip_detach=False) -> None: """ - return MappingProxyType(self._constraints.builtin) - - @property - def user_constraints(self) -> Dict[str, ConstraintBase]: + Revert the dependency to the old dependency. This is used when an error is raised during setting the dependency. 
""" - Get the user specified constrains of the object. - - :return: Dictionary of constraints which are user supplied - """ - return self._constraints.user - - @user_constraints.setter - def user_constraints(self, constraints_dict: Dict[str, ConstraintBase]) -> None: - self._constraints.user = constraints_dict - - def _constraint_runner( - self, - this_constraint_type, - value: numbers.Number, - ) -> float: - for constraint in this_constraint_type.values(): - if constraint.external: - constraint() - continue - - constained_value = constraint(no_set=True) - if constained_value != value: - if global_object.debug: - print(f'Constraint `{constraint}` has been applied') - self._scalar.value = constained_value - value = constained_value - return value - - @property - def enabled(self) -> bool: - """ - Logical property to see if the objects value can be directly set. - - :return: Can the objects value be set - """ - return self._enabled - - @enabled.setter - @property_stack_deco - def enabled(self, value: bool) -> None: - """ - Enable and disable the direct setting of an objects value field. - - :param value: True - objects value can be set, False - the opposite - """ - self._enabled = value + if self._previous_independent is True: + self.make_independent() + else: + if not skip_detach: + for dependency in self._dependency_map.values(): + dependency._detach_observer(self) + for key, value in self._previous_dependency.items(): + setattr(self, key, value) + for dependency in self._dependency_map.values(): + dependency._attach_observer(self) + del self._previous_dependency + del self._previous_independent + + def _process_dependency_unique_names(self, dependency_expression: str): + """ + Add the unique names of the parameters to the ASTEval interpreter. This is used to evaluate the dependency expression. 
+ + :param dependency_expression: The dependency expression to be evaluated + """ + # Get the unique_names from the expression string regardless of the quotes used + inputted_unique_names = re.findall("(\'.+?\')", dependency_expression) + inputted_unique_names += re.findall('(\".+?\")', dependency_expression) + + clean_dependency_string = dependency_expression + existing_unique_names = self._global_object.map.vertices() + # Add the unique names of the parameters to the ASTEVAL interpreter + for name in inputted_unique_names: + stripped_name = name.strip("'\"") + if stripped_name not in existing_unique_names: + raise ValueError(f'A Parameter with unique_name {stripped_name} does not exist. Please check your dependency expression.') # noqa: E501 + dependent_parameter = self._global_object.map.get_item_by_key(stripped_name) + if isinstance(dependent_parameter, DescriptorNumber): + self._dependency_map['__'+stripped_name+'__'] = dependent_parameter + clean_dependency_string = clean_dependency_string.replace(name, '__'+stripped_name+'__') + else: + raise ValueError(f'The object with unique_name {stripped_name} is not a Parameter or DescriptorNumber. 
Please check your dependency expression.') # noqa: E501 + self._clean_dependency_string = clean_dependency_string def __copy__(self) -> Parameter: new_obj = super().__copy__() @@ -450,13 +586,13 @@ def __add__(self, other: Union[DescriptorNumber, Parameter, numbers.Number]) -> elif isinstance(other, DescriptorNumber): # Parameter inherits from DescriptorNumber and is also handled here other_unit = other.unit try: - other.convert_unit(self.unit) + other._convert_unit(self.unit) except UnitError: raise UnitError(f'Values with units {self.unit} and {other.unit} cannot be added') from None new_full_value = self.full_value + other.full_value min_value = self.min + other.min if isinstance(other, Parameter) else self.min + other.value max_value = self.max + other.max if isinstance(other, Parameter) else self.max + other.value - other.convert_unit(other_unit) + other._convert_unit(other_unit) else: return NotImplemented parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) @@ -473,13 +609,13 @@ def __radd__(self, other: Union[DescriptorNumber, numbers.Number]) -> Parameter: elif isinstance(other, DescriptorNumber): # Parameter inherits from DescriptorNumber and is also handled here original_unit = self.unit try: - self.convert_unit(other.unit) + self._convert_unit(other.unit) except UnitError: raise UnitError(f'Values with units {other.unit} and {self.unit} cannot be added') from None new_full_value = self.full_value + other.full_value min_value = self.min + other.value max_value = self.max + other.value - self.convert_unit(original_unit) + self._convert_unit(original_unit) else: return NotImplemented parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) @@ -496,7 +632,7 @@ def __sub__(self, other: Union[DescriptorNumber, Parameter, numbers.Number]) -> elif isinstance(other, DescriptorNumber): # Parameter inherits from DescriptorNumber and is also handled here other_unit 
= other.unit try: - other.convert_unit(self.unit) + other._convert_unit(self.unit) except UnitError: raise UnitError(f'Values with units {self.unit} and {other.unit} cannot be subtracted') from None new_full_value = self.full_value - other.full_value @@ -506,7 +642,7 @@ def __sub__(self, other: Union[DescriptorNumber, Parameter, numbers.Number]) -> else: min_value = self.min - other.value max_value = self.max - other.value - other.convert_unit(other_unit) + other._convert_unit(other_unit) else: return NotImplemented parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) @@ -523,13 +659,13 @@ def __rsub__(self, other: Union[DescriptorNumber, numbers.Number]) -> Parameter: elif isinstance(other, DescriptorNumber): # Parameter inherits from DescriptorNumber and is also handled here original_unit = self.unit try: - self.convert_unit(other.unit) + self._convert_unit(other.unit) except UnitError: raise UnitError(f'Values with units {other.unit} and {self.unit} cannot be subtracted') from None new_full_value = other.full_value - self.full_value min_value = other.value - self.max max_value = other.value - self.min - self.convert_unit(original_unit) + self._convert_unit(original_unit) else: return NotImplemented parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) @@ -573,7 +709,7 @@ def __mul__(self, other: Union[DescriptorNumber, Parameter, numbers.Number]) -> min_value = min(combinations) max_value = max(combinations) parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) - parameter.convert_unit(parameter._base_unit()) + parameter._convert_unit(parameter._base_unit()) parameter.name = parameter.unique_name return parameter @@ -597,7 +733,7 @@ def __rmul__(self, other: Union[DescriptorNumber, numbers.Number]) -> Parameter: min_value = min(combinations) max_value = max(combinations) parameter = 
Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) - parameter.convert_unit(parameter._base_unit()) + parameter._convert_unit(parameter._base_unit()) parameter.name = parameter.unique_name return parameter @@ -639,7 +775,7 @@ def __truediv__(self, other: Union[DescriptorNumber, Parameter, numbers.Number]) min_value = min(combinations) max_value = max(combinations) parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) - parameter.convert_unit(parameter._base_unit()) + parameter._convert_unit(parameter._base_unit()) parameter.name = parameter.unique_name return parameter @@ -680,7 +816,7 @@ def __rtruediv__(self, other: Union[DescriptorNumber, numbers.Number]) -> Parame min_value = min(combinations) max_value = max(combinations) parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) - parameter.convert_unit(parameter._base_unit()) + parameter._convert_unit(parameter._base_unit()) parameter.name = parameter.unique_name self.value = original_self return parameter diff --git a/tests/__init__.py b/tests/__init__.py index a4ab5234..462d95af 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,3 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 
Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project Tuple[List[Parameter], List[int]]: - mock_callback = MagicMock() - mock_callback.fget = MagicMock(return_value=-10) - return [Parameter("a", 1, callback=mock_callback), Parameter("b", 2, callback=mock_callback)], [1, 2] - - -@pytest.fixture -def threePars(twoPars) -> Tuple[List[Parameter], List[int]]: - ps, vs = twoPars - ps.append(Parameter("c", 3)) - vs.append(3) - return ps, vs - - -def test_NumericConstraints_Equals(twoPars): - - value = 1 - - # Should skip - c = NumericConstraint(twoPars[0][0], "==", value) - c() - assert twoPars[0][0].value_no_call_back == twoPars[1][0] - - # Should update to new value - c = NumericConstraint(twoPars[0][1], "==", value) - c() - assert twoPars[0][1].value_no_call_back == value - - -def test_NumericConstraints_Greater(twoPars): - value = 1.5 - - # Should update to new value - c = NumericConstraint(twoPars[0][0], ">", value) - c() - assert twoPars[0][0].value_no_call_back == value - - # Should skip - c = NumericConstraint(twoPars[0][1], ">", value) - c() - assert twoPars[0][1].value_no_call_back == twoPars[1][1] - - -def test_NumericConstraints_Less(twoPars): - value = 1.5 - - # Should skip - c = NumericConstraint(twoPars[0][0], "<", value) - c() - assert twoPars[0][0].value_no_call_back == twoPars[1][0] - - # Should update to new value - c = NumericConstraint(twoPars[0][1], "<", value) - c() - assert twoPars[0][1].value_no_call_back == value - - -@pytest.mark.parametrize("multiplication_factor", [None, 1, 2, 3, 4.5]) -def test_ObjConstraintMultiply(twoPars, multiplication_factor): - if multiplication_factor is None: - multiplication_factor = 1 - operator_str = "" - else: - operator_str = f"{multiplication_factor}*" - c = ObjConstraint(twoPars[0][0], operator_str, twoPars[0][1]) - c() - assert 
twoPars[0][0].value_no_call_back == multiplication_factor * twoPars[1][1] - - -@pytest.mark.parametrize("division_factor", [1, 2, 3, 4.5]) -def test_ObjConstraintDivide(twoPars, division_factor): - operator_str = f"{division_factor}/" - c = ObjConstraint(twoPars[0][0], operator_str, twoPars[0][1]) - c() - assert twoPars[0][0].value_no_call_back == division_factor / twoPars[1][1] - - -def test_ObjConstraint_Multiple(threePars): - - p0 = threePars[0][0] - p1 = threePars[0][1] - p2 = threePars[0][2] - - value = 1.5 - - p0.user_constraints["num_1"] = ObjConstraint(p1, "", p0) - p0.user_constraints["num_2"] = ObjConstraint(p2, "", p0) - - p0.value = value - assert p0.value_no_call_back == value - assert p1.value_no_call_back == value - assert p2.value_no_call_back == value - - -def test_ConstraintEnable_Disable(twoPars): - - assert twoPars[0][0].enabled - assert twoPars[0][1].enabled - - c = ObjConstraint(twoPars[0][0], "", twoPars[0][1]) - twoPars[0][0].user_constraints["num_1"] = c - - assert c.enabled - assert twoPars[0][1].enabled - assert not twoPars[0][0].enabled - - c.enabled = False - assert not c.enabled - assert twoPars[0][1].enabled - assert twoPars[0][0].enabled - - c.enabled = True - assert c.enabled - assert twoPars[0][1].enabled - assert not twoPars[0][0].enabled diff --git a/tests/unit_tests/Objects/__init__.py b/tests/unit_tests/Objects/__init__.py deleted file mode 100644 index 22e236a6..00000000 --- a/tests/unit_tests/Objects/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project " - d = DescriptorNumber("test", 1, unit="cm") - assert repr(d) == f"<{d.__class__.__name__} 'test': 1.0000 cm>" - - -def test_descriptor_number_as_dict(): - d = DescriptorNumber("test", 1) - result = d.as_dict() - expected = { - "@module": 
DescriptorNumber.__module__, - "@class": DescriptorNumber.__name__, - "@version": easyscience.__version__, - "name": "test", - "value": 1, - "unit": "dimensionless", - "description": "", - "url": "", - "display_name": "test", - "callback": None, - } - for key in expected.keys(): - if key == "callback": - continue - assert result[key] == expected[key] - - -@pytest.mark.parametrize( - "reference, constructor", - ( - [ - { - "@module": DescriptorBool.__module__, - "@class": DescriptorBool.__name__, - "@version": easyscience.__version__, - "name": "test", - "value": False, - "description": "", - "url": "", - "display_name": "test", - }, - DescriptorBool, - ], - [ - { - "@module": DescriptorNumber.__module__, - "@class": DescriptorNumber.__name__, - "@version": easyscience.__version__, - "name": "test", - "value": 1, - "unit": "dimensionless", - "variance": 0.0, - "description": "", - "url": "", - "display_name": "test", - }, - DescriptorNumber, - ], - [ - { - "@module": DescriptorStr.__module__, - "@class": DescriptorStr.__name__, - "@version": easyscience.__version__, - "name": "test", - "value": "string", - "description": "", - "url": "", - "display_name": "test", - }, - DescriptorStr, - ], - ), - ids=["DescriptorBool", "DescriptorNumber", "DescriptorStr"], -) -def test_item_from_dict(reference, constructor): - d = constructor.from_dict(reference) - for key, item in reference.items(): - if key.startswith("@"): - continue - obtained = getattr(d, key) - assert obtained == item - - -@pytest.mark.parametrize("value", ("This is ", "a fun ", "test")) -def test_parameter_display_name(value): - p = DescriptorNumber("test", 1, display_name=value) - assert p.display_name == value - - -def test_item_boolean_value(): - item = DescriptorBool("test", True) - assert item.value is True - item.value = False - assert item.value is False - - item = DescriptorBool("test", False) - assert item.value is False - item.value = True - assert item.value is True diff --git 
a/tests/unit_tests/Objects/variable/test_parameter_from_legacy.py b/tests/unit_tests/Objects/variable/test_parameter_from_legacy.py deleted file mode 100644 index f4dcd2ea..00000000 --- a/tests/unit_tests/Objects/variable/test_parameter_from_legacy.py +++ /dev/null @@ -1,424 +0,0 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project " - d = Parameter("test", 1, unit="cm") - assert repr(d) == f"<{d.__class__.__name__} 'test': 1.0000 cm, bounds=[-inf:inf]>" - d = Parameter("test", 1, variance=0.1) - assert repr(d) == f"<{d.__class__.__name__} 'test': 1.0000 ± 0.3162, bounds=[-inf:inf]>" - - d = Parameter("test", 1, fixed=True) - assert ( - repr(d) - == f"<{d.__class__.__name__} 'test': 1.0000 (fixed), bounds=[-inf:inf]>" - ) - d = Parameter("test", 1, unit="cm", variance=0.1, fixed=True) - assert ( - repr(d) - == f"<{d.__class__.__name__} 'test': 1.0000 ± 0.3162 cm (fixed), bounds=[-inf:inf]>" - ) - - -def test_parameter_as_dict(): - d = Parameter("test", 1) - result = d.as_dict() - expected = { - "@module": Parameter.__module__, - "@class": Parameter.__name__, - "@version": easyscience.__version__, - "name": "test", - "value": 1.0, - "variance": 0.0, - "min": -np.inf, - "max": np.inf, - "fixed": False, - "unit": "dimensionless", - } - for key in expected.keys(): - assert result[key] == expected[key] - - # Check that additional arguments work - d = Parameter("test", 1, unit="km", url="https://www.boo.com") - result = d.as_dict() - expected = { - "@module": Parameter.__module__, - "@class": Parameter.__name__, - "@version": easyscience.__version__, - "name": "test", - "unit": "km", - "value": 1.0, - "variance": 0.0, - "min": -np.inf, - "max": np.inf, - "fixed": False, - "url": "https://www.boo.com", - } - for key in expected.keys(): - assert result[key] == expected[key] - - -def test_item_from_dict(): - reference = { - "@module": Parameter.__module__, - "@class": 
Parameter.__name__, - "@version": easyscience.__version__, - "name": "test", - "unit": "km", - "value": 1.0, - "variance": 0.0, - "min": -np.inf, - "max": np.inf, - "fixed": False, - "url": "https://www.boo.com", - } - constructor = Parameter - d = constructor.from_dict(reference) - for key, item in reference.items(): - if key == "callback" or key.startswith("@"): - continue - obtained = getattr(d, key) - assert obtained == item - - -@pytest.mark.parametrize( - "construct", - ( - { - "@module": Parameter.__module__, - "@class": Parameter.__name__, - "@version": easyscience.__version__, - "name": "test", - "unit": "km", - "value": 1.0, - "variance": 0.0, - "min": -np.inf, - "max": np.inf, - "fixed": False, - "url": "https://www.boo.com", - }, - ), - ids=["Parameter"], -) -def test_item_from_Decoder(construct): - - from easyscience.Utils.io.dict import DictSerializer - - d = DictSerializer().decode(construct) - assert d.__class__.__name__ == construct["@class"] - for key, item in construct.items(): - if key == "callback" or key.startswith("@"): - continue - obtained = getattr(d, key) - assert obtained == item - - -@pytest.mark.parametrize("value", (-np.inf, 0, 1.0, 2147483648, np.inf)) -def test_parameter_min(value): - d = Parameter("test", -0.1) - if d.value < value: - with pytest.raises(ValueError): - d.min = value - else: - d.min = value - assert d.min == value - - -@pytest.mark.parametrize("value", [-np.inf, 0, 1.1, 2147483648, np.inf]) -def test_parameter_max(value): - d = Parameter("test", 2147483649) - if d.value > value: - with pytest.raises(ValueError): - d.max = value - else: - d.max = value - assert d.max == value - - -@pytest.mark.parametrize("value", [True, False, 5]) -def test_parameter_fixed(value): - d = Parameter("test", -np.inf) - if isinstance(value, bool): - d.fixed = value - assert d.fixed == value - else: - with pytest.raises(ValueError): - d.fixed = value - - -@pytest.mark.parametrize("value", (-np.inf, -0.1, 0, 1.0, 2147483648, np.inf)) -def 
test_parameter_error(value): - d = Parameter("test", 1) - if value >= 0: - d.error = value - assert d.error == value - else: - with pytest.raises(ValueError): - d.error = value - - -def _generate_advanced_inputs(): - temp = _generate_inputs() - # These will be the optional parameters - advanced = {"variance": 1.0, "min": -0.1, "max": 2147483648, "fixed": False} - advanced_result = { - "variance": {"name": "variance", "value": advanced["variance"]}, - "min": {"name": "min", "value": advanced["min"]}, - "max": {"name": "max", "value": advanced["max"]}, - "fixed": {"name": "fixed", "value": advanced["fixed"]}, - } - - def create_entry(base, key, value, ref, ref_key=None): - this_temp = deepcopy(base) - for item in base: - test, res = item - new_opt = deepcopy(test[1]) - new_res = deepcopy(res) - if ref_key is None: - ref_key = key - new_res[ref_key] = ref - new_opt[key] = value - this_temp.append(([test[0], new_opt], new_res)) - return this_temp - - for add_opt in advanced.keys(): - if isinstance(advanced[add_opt], list): - for idx, item in enumerate(advanced[add_opt]): - temp = create_entry( - temp, - add_opt, - item, - advanced_result[add_opt]["value"][idx], - ref_key=advanced_result[add_opt]["name"], - ) - else: - temp = create_entry( - temp, - add_opt, - advanced[add_opt], - advanced_result[add_opt]["value"], - ref_key=advanced_result[add_opt]["name"], - ) - return temp - - -@pytest.mark.parametrize("element, expected", _generate_advanced_inputs()) -def test_parameter_advanced_creation(element, expected): - if len(element[0]) > 0: - value = element[0][1] - else: - value = element[1]["value"] - if "min" in element[1].keys(): - if element[1]["min"] > value: - with pytest.raises(ValueError): - d = Parameter(*element[0], **element[1]) - elif "max" in element[1].keys(): - if element[1]["max"] < value: - with pytest.raises(ValueError): - d = Parameter(*element[0], **element[1]) - else: - d = Parameter(*element[0], **element[1]) - for field in expected.keys(): - ref = 
expected[field] - obtained = getattr(d, field) - assert obtained == ref - - -@pytest.mark.parametrize("value", ("This is ", "a fun ", "test")) -def test_parameter_display_name(value): - p = Parameter("test", 1, display_name=value) - assert p.display_name == value - - -@pytest.mark.parametrize("value", (True, False)) -def test_parameter_bounds(value): - for fixed in (True, False): - p = Parameter("test", 1, enabled=value, fixed=fixed) - assert p.min == -np.inf - assert p.max == np.inf - assert p.fixed == fixed - assert p.bounds == (-np.inf, np.inf) - - p.bounds = (0, 2) - assert p.min == 0 - assert p.max == 2 - assert p.bounds == (0, 2) - assert p.enabled is True - assert p.fixed is False \ No newline at end of file diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py index 22e236a6..462d95af 100644 --- a/tests/unit_tests/__init__.py +++ b/tests/unit_tests/__init__.py @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project None: # When minimizer._original_fit_function = MagicMock(return_value='fit_function_result') - mock_fit_constraint = MagicMock() - minimizer.fit_constraints = MagicMock(return_value=[mock_fit_constraint]) - minimizer._object = MagicMock() mock_parm_1 = MagicMock(Parameter) mock_parm_1.unique_name = 'mock_parm_1' @@ -148,7 +144,6 @@ def test_generate_fit_function(self, minimizer: MinimizerBase) -> None: # Expect assert 'fit_function_result' == fit_function_result - 
mock_fit_constraint.assert_called_once_with() minimizer._original_fit_function.assert_called_once_with([10.0]) assert minimizer._cached_pars['mock_parm_1'] == mock_parm_1 assert minimizer._cached_pars['mock_parm_2'] == mock_parm_2 diff --git a/tests/unit_tests/Fitting/minimizers/test_minimizer_bumps.py b/tests/unit_tests/fitting/minimizers/test_minimizer_bumps.py similarity index 100% rename from tests/unit_tests/Fitting/minimizers/test_minimizer_bumps.py rename to tests/unit_tests/fitting/minimizers/test_minimizer_bumps.py diff --git a/tests/unit_tests/Fitting/minimizers/test_minimizer_dfo.py b/tests/unit_tests/fitting/minimizers/test_minimizer_dfo.py similarity index 97% rename from tests/unit_tests/Fitting/minimizers/test_minimizer_dfo.py rename to tests/unit_tests/fitting/minimizers/test_minimizer_dfo.py index 8c39b8a5..7d8ca3fe 100644 --- a/tests/unit_tests/Fitting/minimizers/test_minimizer_dfo.py +++ b/tests/unit_tests/fitting/minimizers/test_minimizer_dfo.py @@ -4,7 +4,7 @@ import numpy as np import easyscience.fitting.minimizers.minimizer_dfo -from easyscience.Objects.variable import Parameter +from easyscience.variable import Parameter from easyscience.fitting.minimizers.minimizer_dfo import DFO from easyscience.fitting.minimizers.utils import FitError @@ -72,9 +72,6 @@ def test_generate_fit_function(self, minimizer: DFO) -> None: # When minimizer._original_fit_function = MagicMock(return_value='fit_function_result') - mock_fit_constraint = MagicMock() - minimizer.fit_constraints = MagicMock(return_value=[mock_fit_constraint]) - minimizer._object = MagicMock() mock_parm_1 = MagicMock() mock_parm_1.unique_name = 'mock_parm_1' @@ -92,7 +89,6 @@ def test_generate_fit_function(self, minimizer: DFO) -> None: # Expect assert 'fit_function_result' == fit_function_result - mock_fit_constraint.assert_called_once_with() minimizer._original_fit_function.assert_called_once_with([10.0]) assert minimizer._cached_pars['mock_parm_1'] == mock_parm_1 assert 
minimizer._cached_pars['mock_parm_2'] == mock_parm_2 diff --git a/tests/unit_tests/Fitting/minimizers/test_minimizer_lmfit.py b/tests/unit_tests/fitting/minimizers/test_minimizer_lmfit.py similarity index 99% rename from tests/unit_tests/Fitting/minimizers/test_minimizer_lmfit.py rename to tests/unit_tests/fitting/minimizers/test_minimizer_lmfit.py index f981436e..8a187709 100644 --- a/tests/unit_tests/Fitting/minimizers/test_minimizer_lmfit.py +++ b/tests/unit_tests/fitting/minimizers/test_minimizer_lmfit.py @@ -4,7 +4,7 @@ import easyscience.fitting.minimizers.minimizer_lmfit from easyscience.fitting.minimizers.minimizer_lmfit import LMFit -from easyscience.Objects.variable import Parameter +from easyscience import Parameter from lmfit import Parameter as LMParameter from easyscience.fitting.minimizers.utils import FitError diff --git a/tests/unit_tests/Fitting/test_fitter.py b/tests/unit_tests/fitting/test_fitter.py similarity index 83% rename from tests/unit_tests/Fitting/test_fitter.py rename to tests/unit_tests/fitting/test_fitter.py index 63783c17..dede119c 100644 --- a/tests/unit_tests/Fitting/test_fitter.py +++ b/tests/unit_tests/fitting/test_fitter.py @@ -3,8 +3,8 @@ import pytest import numpy as np import easyscience.fitting.fitter -from easyscience.fitting.fitter import Fitter -from easyscience.fitting.available_minimizers import AvailableMinimizers +from easyscience import Fitter +from easyscience import AvailableMinimizers class TestFitter(): @@ -24,42 +24,6 @@ def test_constructor(self, fitter: Fitter): assert fitter._minimizer is None fitter._update_minimizer.assert_called_once_with(AvailableMinimizers.LMFit_leastsq) - def test_fit_constraints(self, fitter: Fitter): - # When - mock_minimizer = MagicMock() - mock_minimizer.fit_constraints = MagicMock(return_value='constraints') - fitter._minimizer = mock_minimizer - - # Then - constraints = fitter.fit_constraints() - - # Expect - assert constraints == 'constraints' - - def 
test_add_fit_constraint(self, fitter: Fitter): - # When - mock_minimizer = MagicMock() - mock_minimizer.add_fit_constraint = MagicMock() - fitter._minimizer = mock_minimizer - - # Then - fitter.add_fit_constraint('constraints') - - # Expect - mock_minimizer.add_fit_constraint.assert_called_once_with('constraints') - - def test_remove_fit_constraint(self, fitter: Fitter): - # When - mock_minimizer = MagicMock() - mock_minimizer.remove_fit_constraint = MagicMock() - fitter._minimizer = mock_minimizer - - # Then - fitter.remove_fit_constraint(10) - - # Expect - mock_minimizer.remove_fit_constraint.assert_called_once_with(10) - def test_make_model(self, fitter: Fitter): # When mock_minimizer = MagicMock() @@ -128,8 +92,6 @@ def test_create(self, fitter: Fitter, monkeypatch): def test_switch_minimizer(self, fitter: Fitter, monkeypatch): # When mock_minimizer = MagicMock() - mock_minimizer.fit_constraints = MagicMock(return_value='constraints') - mock_minimizer.set_fit_constraint = MagicMock() fitter._minimizer = mock_minimizer mock_string_to_enum = MagicMock(return_value=10) monkeypatch.setattr(easyscience.fitting.fitter, 'from_string_to_enum', mock_string_to_enum) @@ -139,8 +101,6 @@ def test_switch_minimizer(self, fitter: Fitter, monkeypatch): # Expect fitter._update_minimizer.count(2) - mock_minimizer.set_fit_constraint.assert_called_once_with('constraints') - mock_minimizer.fit_constraints.assert_called_once() mock_string_to_enum.assert_called_once_with('great-minimizer') def test_update_minimizer(self, monkeypatch): diff --git a/tests/unit_tests/global_object/__init__.py b/tests/unit_tests/global_object/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/global_object/test_global_object.py b/tests/unit_tests/global_object/test_global_object.py index 2997b523..6f0c463b 100644 --- a/tests/unit_tests/global_object/test_global_object.py +++ b/tests/unit_tests/global_object/test_global_object.py @@ -1,6 +1,6 @@ import easyscience from 
easyscience.global_object.global_object import GlobalObject -from easyscience.Objects.variable.descriptor_bool import DescriptorBool +from easyscience.variable import DescriptorBool class TestGlobalObject: def test_init(self): diff --git a/tests/unit_tests/global_object/test_map.py b/tests/unit_tests/global_object/test_map.py index b3eae678..df344aed 100644 --- a/tests/unit_tests/global_object/test_map.py +++ b/tests/unit_tests/global_object/test_map.py @@ -1,10 +1,9 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-License-Identifier: BSD-3-Clause +# © 2025 Contributors to the EasyScience project + diff --git a/tests/unit_tests/utils/io_tests/test_core.py b/tests/unit_tests/io/test_serializer_component.py similarity index 51% rename from tests/unit_tests/utils/io_tests/test_core.py rename to tests/unit_tests/io/test_serializer_component.py index 3e87d539..80994b4d 100644 --- a/tests/unit_tests/utils/io_tests/test_core.py +++ b/tests/unit_tests/io/test_serializer_component.py @@ -1,5 +1,3 @@ -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" import numpy as np from copy import deepcopy @@ -8,9 +6,8 @@ import pytest import easyscience -from easyscience.Objects.ObjectClasses import BaseObj -from easyscience.Objects.variable import DescriptorNumber -from easyscience.Objects.variable import Parameter +from easyscience import DescriptorNumber +from easyscience import Parameter dp_param_dict = { "argnames": "dp_kwargs, dp_cls", @@ -45,7 +42,6 @@ "url": "https://www.boo.com", "description": "", "display_name": "test", - "enabled": True, }, Parameter, ], @@ -99,86 +95,3 @@ def test_variable_as_dict_methods(dp_kwargs: dict, dp_cls: 
Type[DescriptorNumber check_dict(dp_kwargs, enc) - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_as_data_dict_methods(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - if isinstance(skip, str): - del data_dict[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc_d = obj.as_data_dict(skip=skip) - - expected_keys = set(data_dict.keys()) - obtained_keys = set(enc_d.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(data_dict, enc_d) - - -class A(BaseObj): - def __init__(self, name: str = "A", **kwargs): - super().__init__(name=name, **kwargs) - - -class B(BaseObj): - def __init__(self, a, b, unique_name): - super(B, self).__init__("B", a=a, unique_name=unique_name) - self.b = b - - -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_as_dict_methods(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = { - "@module": A.__module__, - "@class": A.__name__, - "@version": None, - "name": "A", - dp_kwargs["name"]: dp_kwargs, - } - - obj = A(**a_kw) - - enc = obj.as_dict() - expected_keys = set(full_d.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, enc) - - -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_as_data_dict_methods(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = {"name": "A", dp_kwargs["name"]: data_dict} - - obj = A(**a_kw) - - enc = obj.as_data_dict() - expected_keys = set(full_d.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) 
- - assert len(dif) == 0 - - check_dict(full_d, enc) diff --git a/tests/unit_tests/io/test_serializer_dict.py b/tests/unit_tests/io/test_serializer_dict.py new file mode 100644 index 00000000..7d75f02d --- /dev/null +++ b/tests/unit_tests/io/test_serializer_dict.py @@ -0,0 +1,117 @@ + +from copy import deepcopy +from typing import Type + +import pytest + +from easyscience.io.serializer_dict import SerializerDict +from easyscience import DescriptorNumber +from easyscience import ObjBase + +from .test_serializer_component import check_dict +from .test_serializer_component import dp_param_dict +from .test_serializer_component import skip_dict +from easyscience import global_object + + +def recursive_remove(d, remove_keys: list) -> dict: + """ + Remove keys from a dictionary. + """ + if not isinstance(remove_keys, list): + remove_keys = [remove_keys] + if isinstance(d, dict): + dd = {} + for k in d.keys(): + if k not in remove_keys: + dd[k] = recursive_remove(d[k], remove_keys) + return dd + else: + return d + + +######################################################################################################################## +# TESTING ENCODING +######################################################################################################################## +@pytest.mark.parametrize(**skip_dict) +@pytest.mark.parametrize(**dp_param_dict) +def test_variable_SerializerDict(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): + data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + + obj = dp_cls(**data_dict) + + dp_kwargs = deepcopy(dp_kwargs) + + if isinstance(skip, str): + del dp_kwargs[skip] + + if not isinstance(skip, list): + skip = [skip] + + enc = obj.encode(skip=skip, encoder=SerializerDict) + + expected_keys = set(dp_kwargs.keys()) + obtained_keys = set(enc.keys()) + + dif = expected_keys.difference(obtained_keys) + + assert len(dif) == 0 + + check_dict(dp_kwargs, enc) + 
+######################################################################################################################## +# TESTING DECODING +######################################################################################################################## +@pytest.mark.parametrize(**dp_param_dict) +def test_variable_SerializerDict_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): + data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + + obj = dp_cls(**data_dict) + + enc = obj.encode(encoder=SerializerDict) + global_object.map._clear() + dec = dp_cls.decode(enc, decoder=SerializerDict) + + for k in data_dict.keys(): + if hasattr(obj, k) and hasattr(dec, k): + assert getattr(obj, k) == getattr(dec, k) + else: + raise AttributeError(f"{k} not found in decoded object") + + +@pytest.mark.parametrize(**dp_param_dict) +def test_variable_SerializerDict_from_dict(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): + data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + + obj = dp_cls(**data_dict) + + enc = obj.encode(encoder=SerializerDict) + global_object.map._clear() + dec = dp_cls.from_dict(enc) + + for k in data_dict.keys(): + if hasattr(obj, k) and hasattr(dec, k): + assert getattr(obj, k) == getattr(dec, k) + else: + raise AttributeError(f"{k} not found in decoded object") + +def test_group_encode(): + d0 = DescriptorNumber("a", 0) + d1 = DescriptorNumber("b", 1) + + from easyscience.base_classes import CollectionBase + + b = CollectionBase("test", d0, d1) + d = b.as_dict() + assert isinstance(d["data"], list) + + +def test_group_encode2(): + d0 = DescriptorNumber("a", 0) + d1 = DescriptorNumber("b", 1) + + from easyscience.base_classes import CollectionBase + + b = ObjBase("outer", b=CollectionBase("test", d0, d1)) + d = b.as_dict() + assert isinstance(d["b"], dict) \ No newline at end of file diff --git a/tests/unit_tests/legacy/test_dict.py b/tests/unit_tests/legacy/test_dict.py new file mode 100644 index 00000000..70af9d78 --- 
/dev/null +++ b/tests/unit_tests/legacy/test_dict.py @@ -0,0 +1,156 @@ + +# from copy import deepcopy +# from typing import Type + +# import pytest + +# from easyscience.io.dict import DataDictSerializer +# from easyscience.io.dict import DictSerializer +# from easyscience.variable import DescriptorNumber +# from easyscience.base_classes import BaseObj + +# from .test_core import check_dict +# from .test_core import dp_param_dict +# from .test_core import skip_dict +# from easyscience import global_object + + +# def recursive_remove(d, remove_keys: list) -> dict: +# """ +# Remove keys from a dictionary. +# """ +# if not isinstance(remove_keys, list): +# remove_keys = [remove_keys] +# if isinstance(d, dict): +# dd = {} +# for k in d.keys(): +# if k not in remove_keys: +# dd[k] = recursive_remove(d[k], remove_keys) +# return dd +# else: +# return d + + +# ######################################################################################################################## +# # TESTING ENCODING +# ######################################################################################################################## +# @pytest.mark.parametrize(**skip_dict) +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# dp_kwargs = deepcopy(dp_kwargs) + +# if isinstance(skip, str): +# del dp_kwargs[skip] + +# if not isinstance(skip, list): +# skip = [skip] + +# enc = obj.encode(skip=skip, encoder=DictSerializer) + +# expected_keys = set(dp_kwargs.keys()) +# obtained_keys = set(enc.keys()) + +# dif = expected_keys.difference(obtained_keys) + +# assert len(dif) == 0 + +# check_dict(dp_kwargs, enc) + + +# @pytest.mark.parametrize(**skip_dict) +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DataDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): +# data_dict = 
{k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# if isinstance(skip, str): +# del data_dict[skip] + +# if not isinstance(skip, list): +# skip = [skip] + +# enc_d = obj.encode(skip=skip, encoder=DataDictSerializer) + +# expected_keys = set(data_dict.keys()) +# obtained_keys = set(enc_d.keys()) + +# dif = expected_keys.difference(obtained_keys) + +# assert len(dif) == 0 + +# check_dict(data_dict, enc_d) + + +# ######################################################################################################################## +# # TESTING DECODING +# ######################################################################################################################## +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=DictSerializer) +# global_object.map._clear() +# dec = dp_cls.decode(enc, decoder=DictSerializer) + +# for k in data_dict.keys(): +# if hasattr(obj, k) and hasattr(dec, k): +# assert getattr(obj, k) == getattr(dec, k) +# else: +# raise AttributeError(f"{k} not found in decoded object") + + +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DictSerializer_from_dict(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=DictSerializer) +# global_object.map._clear() +# dec = dp_cls.from_dict(enc) + +# for k in data_dict.keys(): +# if hasattr(obj, k) and hasattr(dec, k): +# assert getattr(obj, k) == getattr(dec, k) +# else: +# raise AttributeError(f"{k} not found in decoded object") + + +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DataDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in 
dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=DataDictSerializer) +# with pytest.raises(NotImplementedError): +# dec = obj.decode(enc, decoder=DataDictSerializer) + + +# def test_group_encode(): +# d0 = DescriptorNumber("a", 0) +# d1 = DescriptorNumber("b", 1) + +# from easyscience.base_classes import BaseCollection + +# b = BaseCollection("test", d0, d1) +# d = b.as_dict() +# assert isinstance(d["data"], list) + + +# def test_group_encode2(): +# d0 = DescriptorNumber("a", 0) +# d1 = DescriptorNumber("b", 1) + +# from easyscience.base_classes import BaseCollection + +# b = BaseObj("outer", b=BaseCollection("test", d0, d1)) +# d = b.as_dict() +# assert isinstance(d["b"], dict) \ No newline at end of file diff --git a/tests/unit_tests/legacy/test_json.py b/tests/unit_tests/legacy/test_json.py new file mode 100644 index 00000000..651bd950 --- /dev/null +++ b/tests/unit_tests/legacy/test_json.py @@ -0,0 +1,123 @@ + +# import json +# from copy import deepcopy +# from typing import Type + +# import pytest + +# from easyscience.io.json import JsonDataSerializer +# from easyscience.io.json import JsonSerializer +# from easyscience.variable import DescriptorNumber + +# from .test_core import check_dict +# from .test_core import dp_param_dict +# from .test_core import skip_dict +# from easyscience import global_object + + +# def recursive_remove(d, remove_keys: list) -> dict: +# """ +# Remove keys from a dictionary. 
+# """ +# if not isinstance(remove_keys, list): +# remove_keys = [remove_keys] +# if isinstance(d, dict): +# dd = {} +# for k in d.keys(): +# if k not in remove_keys: +# dd[k] = recursive_remove(d[k], remove_keys) +# return dd +# else: +# return d + + +# ######################################################################################################################## +# # TESTING ENCODING +# ######################################################################################################################## +# @pytest.mark.parametrize(**skip_dict) +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# dp_kwargs = deepcopy(dp_kwargs) + +# if isinstance(skip, str): +# del dp_kwargs[skip] + +# if not isinstance(skip, list): +# skip = [skip] + +# enc = obj.encode(skip=skip, encoder=JsonSerializer) +# assert isinstance(enc, str) + +# # We can test like this as we don't have "complex" objects yet +# dec = json.loads(enc) +# expected_keys = set(dp_kwargs.keys()) +# obtained_keys = set(dec.keys()) + +# dif = expected_keys.difference(obtained_keys) + +# assert len(dif) == 0 + +# check_dict(dp_kwargs, dec) + + +# @pytest.mark.parametrize(**skip_dict) +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DataDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# if isinstance(skip, str): +# del data_dict[skip] + +# if not isinstance(skip, list): +# skip = [skip] + +# enc = obj.encode(skip=skip, encoder=JsonDataSerializer) +# assert isinstance(enc, str) +# enc_d = json.loads(enc) + +# expected_keys = set(data_dict.keys()) +# obtained_keys = set(enc_d.keys()) + +# dif = expected_keys.difference(obtained_keys) + +# assert len(dif) == 0 + +# 
check_dict(data_dict, enc_d) + +# # ######################################################################################################################## +# # # TESTING DECODING +# # ######################################################################################################################## +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=JsonSerializer) +# global_object.map._clear() +# assert isinstance(enc, str) +# dec = obj.decode(enc, decoder=JsonSerializer) + +# for k in data_dict.keys(): +# if hasattr(obj, k) and hasattr(dec, k): +# assert getattr(obj, k) == getattr(dec, k) +# else: +# raise AttributeError(f"{k} not found in decoded object") + + +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DataDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=JsonDataSerializer) +# global_object.map._clear() +# with pytest.raises(NotImplementedError): +# dec = obj.decode(enc, decoder=JsonDataSerializer) diff --git a/tests/unit_tests/legacy/test_xml.py b/tests/unit_tests/legacy/test_xml.py new file mode 100644 index 00000000..7094de85 --- /dev/null +++ b/tests/unit_tests/legacy/test_xml.py @@ -0,0 +1,111 @@ + +# import sys +# import xml.etree.ElementTree as ET +# from copy import deepcopy +# from typing import Type + +# import pytest + +# from easyscience.legacy.xml import XMLSerializer +# from easyscience.variable import DescriptorNumber + +# from ..io.test_core import dp_param_dict +# from ..io.test_core import skip_dict +# from easyscience import global_object + +# def recursive_remove(d, remove_keys: list) -> dict: +# """ +# Remove keys from a dictionary. 
+# """ +# if not isinstance(remove_keys, list): +# remove_keys = [remove_keys] +# if isinstance(d, dict): +# dd = {} +# for k in d.keys(): +# if k not in remove_keys: +# dd[k] = recursive_remove(d[k], remove_keys) +# return dd +# else: +# return d + + +# def recursive_test(testing_obj, reference_obj): +# for i, (k, v) in enumerate(testing_obj.items()): +# if isinstance(v, dict): +# recursive_test(v, reference_obj[i]) +# else: +# assert v == XMLSerializer.string_to_variable(reference_obj[i].text) + + +# ######################################################################################################################## +# # TESTING ENCODING +# ######################################################################################################################## +# @pytest.mark.parametrize(**skip_dict) +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_XMLDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# dp_kwargs = deepcopy(dp_kwargs) + +# if isinstance(skip, str): +# del dp_kwargs[skip] + +# if not isinstance(skip, list): +# skip = [skip] + +# enc = obj.encode(skip=skip, encoder=XMLSerializer) +# ref_encode = obj.encode(skip=skip) +# assert isinstance(enc, str) +# data_xml = ET.XML(enc) +# assert data_xml.tag == "data" +# recursive_test(data_xml, ref_encode) + +# # ######################################################################################################################## +# # # TESTING DECODING +# # ######################################################################################################################## +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_XMLDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=XMLSerializer) +# assert 
isinstance(enc, str) +# data_xml = ET.XML(enc) +# assert data_xml.tag == "data" +# global_object.map._clear() +# dec = dp_cls.decode(enc, decoder=XMLSerializer) + +# for k in data_dict.keys(): +# if hasattr(obj, k) and hasattr(dec, k): +# assert getattr(obj, k) == getattr(dec, k) +# else: +# raise AttributeError(f"{k} not found in decoded object") + + +# def test_slow_encode(): + +# if sys.version_info < (3, 9): +# pytest.skip("This test is only for python 3.9+") + +# a = {"a": [1, 2, 3]} +# slow_xml = XMLSerializer().encode(a, fast=False) +# reference = """ +# 1 +# 2 +# 3 +# """ +# assert slow_xml == reference + + +# def test_include_header(): + +# if sys.version_info < (3, 9): +# pytest.skip("This test is only for python 3.9+") + +# a = {"a": [1, 2, 3]} +# header_xml = XMLSerializer().encode(a, use_header=True) +# reference = '?xml version="1.0" encoding="UTF-8"?\n\n 1\n 2\n 3\n' +# assert header_xml == reference diff --git a/tests/unit_tests/models/__init__.py b/tests/unit_tests/models/__init__.py index 3d57f66c..bb769856 100644 --- a/tests/unit_tests/models/__init__.py +++ b/tests/unit_tests/models/__init__.py @@ -1,6 +1,4 @@ -# SPDX-FileCopyrightText: 2022 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2022 Contributors to the EasyScience project +# © 2025 Contributors to the EasyScience project -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" diff --git a/tests/unit_tests/models/test_polynomial.py b/tests/unit_tests/models/test_polynomial.py index adccb9e1..799a917a 100644 --- a/tests/unit_tests/models/test_polynomial.py +++ b/tests/unit_tests/models/test_polynomial.py @@ -1,18 +1,13 @@ -# SPDX-FileCopyrightText: 2022 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2022 Contributors to the EasyScience project +# © 2025 Contributors to the EasyScience project -__author__ = 
"github.com/wardsimon" -__version__ = "0.0.1" import numpy as np import pytest -from easyscience.models.polynomial import Line from easyscience.models.polynomial import Polynomial -from easyscience.Objects.variable.parameter import Parameter -line_test_cases = ((1, 2), (-1, -2), (0.72, 6.48)) poly_test_cases = ( (1.,), ( @@ -24,33 +19,6 @@ (0.72, 6.48, -0.48), ) - -@pytest.mark.parametrize("m, c", line_test_cases) -def test_Line_pars(m, c): - line = Line(m, c) - - assert line.m.value == m - assert line.c.value == c - - x = np.linspace(0, 10, 100) - y = line.m.value * x + line.c.value - assert np.allclose(line(x), y) - - -@pytest.mark.parametrize("m, c", line_test_cases) -def test_Line_constructor(m, c): - m_ = Parameter("m", m) - c_ = Parameter("c", c) - line = Line(m_, c_) - - assert line.m.value == m - assert line.c.value == c - - x = np.linspace(0, 10, 100) - y = line.m.value * x + line.c.value - assert np.allclose(line(x), y) - - @pytest.mark.parametrize("coo", poly_test_cases) def test_Polynomial_pars(coo): poly = Polynomial(coefficients=coo) diff --git a/tests/unit_tests/utils/__init__.py b/tests/unit_tests/utils/__init__.py deleted file mode 100644 index 3d57f66c..00000000 --- a/tests/unit_tests/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: 2022 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2022 Contributors to the EasyScience project - -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" diff --git a/tests/unit_tests/utils/io_tests/__init__.py b/tests/unit_tests/utils/io_tests/__init__.py deleted file mode 100644 index 3d57f66c..00000000 --- a/tests/unit_tests/utils/io_tests/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: 2022 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2022 Contributors to the EasyScience project - -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" diff --git a/tests/unit_tests/utils/io_tests/test_dict.py 
b/tests/unit_tests/utils/io_tests/test_dict.py deleted file mode 100644 index 884f86b6..00000000 --- a/tests/unit_tests/utils/io_tests/test_dict.py +++ /dev/null @@ -1,419 +0,0 @@ -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" - -from copy import deepcopy -from typing import Type - -import numpy as np -import pytest -from importlib import metadata - -from easyscience.Utils.io.dict import DataDictSerializer -from easyscience.Utils.io.dict import DictSerializer -from easyscience.Objects.variable import DescriptorNumber -from easyscience.Objects.ObjectClasses import BaseObj - -from .test_core import A -from .test_core import B -from .test_core import check_dict -from .test_core import dp_param_dict -from .test_core import skip_dict -from easyscience import global_object - - -def recursive_remove(d, remove_keys: list) -> dict: - """ - Remove keys from a dictionary. - """ - if not isinstance(remove_keys, list): - remove_keys = [remove_keys] - if isinstance(d, dict): - dd = {} - for k in d.keys(): - if k not in remove_keys: - dd[k] = recursive_remove(d[k], remove_keys) - return dd - else: - return d - - -######################################################################################################################## -# TESTING ENCODING -######################################################################################################################## -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - dp_kwargs = deepcopy(dp_kwargs) - - if isinstance(skip, str): - del dp_kwargs[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc = obj.encode(skip=skip, encoder=DictSerializer) - - expected_keys = set(dp_kwargs.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert 
len(dif) == 0 - - check_dict(dp_kwargs, enc) - - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DataDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - if isinstance(skip, str): - del data_dict[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc_d = obj.encode(skip=skip, encoder=DataDictSerializer) - - expected_keys = set(data_dict.keys()) - obtained_keys = set(enc_d.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(data_dict, enc_d) - - -@pytest.mark.parametrize( - "encoder", [None, DataDictSerializer], ids=["Default", "DataDictSerializer"] -) -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_encode_data(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip, encoder): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - if isinstance(skip, str): - del data_dict[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc_d = obj.encode_data(skip=skip, encoder=encoder) - - expected_keys = set(data_dict.keys()) - obtained_keys = set(enc_d.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(data_dict, enc_d) - - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_DictSerializer_encode( - dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip -): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = { - "@module": A.__module__, - "@class": A.__name__, - "@version": None, - "name": "A", - dp_kwargs["name"]: deepcopy(dp_kwargs), - } - - if not isinstance(skip, list): - skip = [skip] - - full_d = recursive_remove(full_d, skip) - - obj = A(**a_kw) - - enc = 
obj.encode(skip=skip, encoder=DictSerializer) - expected_keys = set(full_d.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, enc) - - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_DataDictSerializer( - dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip -): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = {"name": "A", dp_kwargs["name"]: data_dict} - - full_d = recursive_remove(full_d, skip) - - obj = A(**a_kw) - - enc = obj.encode(skip=skip, encoder=DataDictSerializer) - expected_keys = set(full_d.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, enc) - - -@pytest.mark.parametrize( - "encoder", [None, DataDictSerializer], ids=["Default", "DataDictSerializer"] -) -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_encode_data(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], encoder): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = {"name": "A", dp_kwargs["name"]: data_dict} - - obj = A(**a_kw) - - enc = obj.encode_data(encoder=encoder) - expected_keys = set(full_d.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, enc) - - -def test_custom_class_full_encode_with_numpy(): - class B(BaseObj): - def __init__(self, a, b, unique_name): - super(B, self).__init__("B", a=a, unique_name=unique_name) - self.b = b - # Same as in __init__.py for easyscience - try: - version = metadata.version('easyscience') # 'easyscience' is the name of the package in 'setup.py - except metadata.PackageNotFoundError: - version = '0.0.0' - - obj = B(DescriptorNumber("a", 1.0, unique_name="a"), 
np.array([1.0, 2.0, 3.0]), unique_name="B_0") - full_enc = obj.encode(encoder=DictSerializer, full_encode=True) - expected = { - "@module": "tests.unit_tests.utils.io_tests.test_dict", - "@class": "B", - "@version": None, - "unique_name": "B_0", - "b": { - "@module": "numpy", - "@class": "array", - "dtype": "float64", - "data": [1.0, 2.0, 3.0], - }, - "a": { - "@module": "easyscience.Objects.variable.descriptor_number", - "@class": "DescriptorNumber", - "@version": version, - "description": "", - "unit": "dimensionless", - "display_name": "a", - "name": "a", - "value": 1.0, - "variance": None, - "unique_name": "a", - "url": "", - }, - } - check_dict(full_enc, expected) - - -def test_custom_class_full_decode_with_numpy(): - global_object.map._clear() - obj = B(DescriptorNumber("a", 1.0), np.array([1.0, 2.0, 3.0]), unique_name="B_0") - full_enc = obj.encode(encoder=DictSerializer, full_encode=True) - global_object.map._clear() - obj2 = B.decode(full_enc, decoder=DictSerializer) - assert obj.name == obj2.name - assert obj.unique_name == obj2.unique_name - assert obj.a.value == obj2.a.value - assert np.all(obj.b == obj2.b) - - -######################################################################################################################## -# TESTING DECODING -######################################################################################################################## -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=DictSerializer) - global_object.map._clear() - dec = dp_cls.decode(enc, decoder=DictSerializer) - - for k in data_dict.keys(): - if hasattr(obj, k) and hasattr(dec, k): - assert getattr(obj, k) == getattr(dec, k) - else: - raise AttributeError(f"{k} not found in decoded object") - - -@pytest.mark.parametrize(**dp_param_dict) -def 
test_variable_DictSerializer_from_dict(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=DictSerializer) - global_object.map._clear() - dec = dp_cls.from_dict(enc) - - for k in data_dict.keys(): - if hasattr(obj, k) and hasattr(dec, k): - assert getattr(obj, k) == getattr(dec, k) - else: - raise AttributeError(f"{k} not found in decoded object") - - -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DataDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=DataDictSerializer) - with pytest.raises(NotImplementedError): - dec = obj.decode(enc, decoder=DataDictSerializer) - - -def test_group_encode(): - d0 = DescriptorNumber("a", 0) - d1 = DescriptorNumber("b", 1) - - from easyscience.Objects.Groups import BaseCollection - - b = BaseCollection("test", d0, d1) - d = b.as_dict() - assert isinstance(d["data"], list) - - -def test_group_encode2(): - d0 = DescriptorNumber("a", 0) - d1 = DescriptorNumber("b", 1) - - from easyscience.Objects.Groups import BaseCollection - - b = BaseObj("outer", b=BaseCollection("test", d0, d1)) - d = b.as_dict() - assert isinstance(d["b"], dict) - - -#TODO: do we need/want this test? 
-# -# @pytest.mark.parametrize(**dp_param_dict) -# def test_custom_class_DictSerializer_decode(dp_kwargs: dict, dp_cls: Type[Descriptor]): -# -# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != '@'} -# -# a_kw = { -# data_dict['name']: dp_cls(**data_dict) -# } -# -# obj = A(**a_kw) -# -# enc = obj.encode(encoder=DictSerializer) -# -# stripped_encode = {k: v for k, v in enc.items() if k[0] != '@'} -# stripped_encode[data_dict['name']] = data_dict -# -# dec = obj.decode(enc, decoder=DictSerializer) -# -# def test_objs(reference_obj, test_obj, in_dict): -# if 'value' in in_dict.keys(): -# in_dict['value'] = in_dict.pop('value') -# if 'units' in in_dict.keys(): -# del in_dict['units'] -# for k in in_dict.keys(): -# if hasattr(reference_obj, k) and hasattr(test_obj, k): -# if isinstance(in_dict[k], dict): -# test_objs(getattr(obj, k), getattr(test_obj, k), in_dict[k]) -# assert getattr(obj, k) == getattr(dec, k) -# else: -# raise AttributeError(f"{k} not found in decoded object") -# test_objs(obj, dec, stripped_encode) -# -# -# @pytest.mark.parametrize(**skip_dict) -# @pytest.mark.parametrize(**dp_param_dict) -# def test_custom_class_DataDictSerializer(dp_kwargs: dict, dp_cls: Type[Descriptor], skip): -# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != '@'} -# -# a_kw = { -# data_dict['name']: dp_cls(**data_dict) -# } -# -# full_d = { -# "name": "A", -# dp_kwargs['name']: data_dict -# } -# -# full_d = recursive_remove(full_d, skip) -# -# obj = A(**a_kw) -# -# enc = obj.encode(skip=skip, encoder=DataDictSerializer) -# expected_keys = set(full_d.keys()) -# obtained_keys = set(enc.keys()) -# -# dif = expected_keys.difference(obtained_keys) -# -# assert len(dif) == 0 -# -# check_dict(full_d, enc) -# -# -# @pytest.mark.parametrize('encoder', [None, DataDictSerializer], ids=['Default', 'DataDictSerializer']) -# @pytest.mark.parametrize(**dp_param_dict) -# def test_custom_class_encode_data(dp_kwargs: dict, dp_cls: Type[Descriptor], encoder): -# data_dict = 
{k: v for k, v in dp_kwargs.items() if k[0] != '@'} -# -# a_kw = { -# data_dict['name']: dp_cls(**data_dict) -# } -# -# full_d = { -# "name": "A", -# dp_kwargs['name']: data_dict -# } -# -# obj = A(**a_kw) -# -# enc = obj.encode_data(encoder=encoder) -# expected_keys = set(full_d.keys()) -# obtained_keys = set(enc.keys()) -# -# dif = expected_keys.difference(obtained_keys) -# -# assert len(dif) == 0 -# -# check_dict(full_d, enc) diff --git a/tests/unit_tests/utils/io_tests/test_json.py b/tests/unit_tests/utils/io_tests/test_json.py deleted file mode 100644 index cec6e4c0..00000000 --- a/tests/unit_tests/utils/io_tests/test_json.py +++ /dev/null @@ -1,198 +0,0 @@ -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" - -import json -from copy import deepcopy -from typing import Type - -import pytest - -from easyscience.Utils.io.json import JsonDataSerializer -from easyscience.Utils.io.json import JsonSerializer -from easyscience.Objects.variable import DescriptorNumber - -from .test_core import A -from .test_core import check_dict -from .test_core import dp_param_dict -from .test_core import skip_dict -from easyscience import global_object - - -def recursive_remove(d, remove_keys: list) -> dict: - """ - Remove keys from a dictionary. 
- """ - if not isinstance(remove_keys, list): - remove_keys = [remove_keys] - if isinstance(d, dict): - dd = {} - for k in d.keys(): - if k not in remove_keys: - dd[k] = recursive_remove(d[k], remove_keys) - return dd - else: - return d - - -######################################################################################################################## -# TESTING ENCODING -######################################################################################################################## -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - dp_kwargs = deepcopy(dp_kwargs) - - if isinstance(skip, str): - del dp_kwargs[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc = obj.encode(skip=skip, encoder=JsonSerializer) - assert isinstance(enc, str) - - # We can test like this as we don't have "complex" objects yet - dec = json.loads(enc) - expected_keys = set(dp_kwargs.keys()) - obtained_keys = set(dec.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(dp_kwargs, dec) - - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DataDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - if isinstance(skip, str): - del data_dict[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc = obj.encode(skip=skip, encoder=JsonDataSerializer) - assert isinstance(enc, str) - enc_d = json.loads(enc) - - expected_keys = set(data_dict.keys()) - obtained_keys = set(enc_d.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(data_dict, enc_d) - - -@pytest.mark.parametrize(**skip_dict) 
-@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_DictSerializer_encode( - dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip -): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = { - "@module": A.__module__, - "@class": A.__name__, - "@version": None, - "name": "A", - dp_kwargs["name"]: deepcopy(dp_kwargs), - } - - if not isinstance(skip, list): - skip = [skip] - - full_d = recursive_remove(full_d, skip) - - obj = A(**a_kw) - - enc = obj.encode(skip=skip, encoder=JsonSerializer) - assert isinstance(enc, str) - - # We can test like this as we don't have "complex" objects yet - dec = json.loads(enc) - - expected_keys = set(full_d.keys()) - obtained_keys = set(dec.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, dec) - - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_DataDictSerializer( - dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip -): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = {"name": "A", dp_kwargs["name"]: data_dict} - - if not isinstance(skip, list): - skip = [skip] - - full_d = recursive_remove(full_d, skip) - - obj = A(**a_kw) - - enc = obj.encode(skip=skip, encoder=JsonDataSerializer) - dec = json.loads(enc) - - expected_keys = set(full_d.keys()) - obtained_keys = set(dec.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, dec) - - -# ######################################################################################################################## -# # TESTING DECODING -# ######################################################################################################################## -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DictSerializer_decode(dp_kwargs: dict, 
dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=JsonSerializer) - global_object.map._clear() - assert isinstance(enc, str) - dec = obj.decode(enc, decoder=JsonSerializer) - - for k in data_dict.keys(): - if hasattr(obj, k) and hasattr(dec, k): - assert getattr(obj, k) == getattr(dec, k) - else: - raise AttributeError(f"{k} not found in decoded object") - - -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DataDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=JsonDataSerializer) - global_object.map._clear() - with pytest.raises(NotImplementedError): - dec = obj.decode(enc, decoder=JsonDataSerializer) diff --git a/tests/unit_tests/utils/io_tests/test_xml.py b/tests/unit_tests/utils/io_tests/test_xml.py deleted file mode 100644 index 2edb761e..00000000 --- a/tests/unit_tests/utils/io_tests/test_xml.py +++ /dev/null @@ -1,147 +0,0 @@ -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" - -import sys -import xml.etree.ElementTree as ET -from copy import deepcopy -from typing import Type - -import pytest - -from easyscience.Utils.io.xml import XMLSerializer -from easyscience.Objects.variable import DescriptorNumber - -from .test_core import A -from .test_core import dp_param_dict -from .test_core import skip_dict -from easyscience import global_object - -def recursive_remove(d, remove_keys: list) -> dict: - """ - Remove keys from a dictionary. 
- """ - if not isinstance(remove_keys, list): - remove_keys = [remove_keys] - if isinstance(d, dict): - dd = {} - for k in d.keys(): - if k not in remove_keys: - dd[k] = recursive_remove(d[k], remove_keys) - return dd - else: - return d - - -def recursive_test(testing_obj, reference_obj): - for i, (k, v) in enumerate(testing_obj.items()): - if isinstance(v, dict): - recursive_test(v, reference_obj[i]) - else: - assert v == XMLSerializer.string_to_variable(reference_obj[i].text) - - -######################################################################################################################## -# TESTING ENCODING -######################################################################################################################## -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_XMLDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - dp_kwargs = deepcopy(dp_kwargs) - - if isinstance(skip, str): - del dp_kwargs[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc = obj.encode(skip=skip, encoder=XMLSerializer) - ref_encode = obj.encode(skip=skip) - assert isinstance(enc, str) - data_xml = ET.XML(enc) - assert data_xml.tag == "data" - recursive_test(data_xml, ref_encode) - - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_XMLDictSerializer_encode( - dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip -): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = { - "@module": A.__module__, - "@class": A.__name__, - "@version": None, - "name": "A", - dp_kwargs["name"]: deepcopy(dp_kwargs), - } - - if not isinstance(skip, list): - skip = [skip] - - full_d = recursive_remove(full_d, skip) - - obj = A(**a_kw) - - enc = obj.encode(skip=skip, 
encoder=XMLSerializer) - ref_encode = obj.encode(skip=skip) - assert isinstance(enc, str) - data_xml = ET.XML(enc) - assert data_xml.tag == "data" - recursive_test(data_xml, ref_encode) - - -# ######################################################################################################################## -# # TESTING DECODING -# ######################################################################################################################## -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_XMLDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=XMLSerializer) - assert isinstance(enc, str) - data_xml = ET.XML(enc) - assert data_xml.tag == "data" - global_object.map._clear() - dec = dp_cls.decode(enc, decoder=XMLSerializer) - - for k in data_dict.keys(): - if hasattr(obj, k) and hasattr(dec, k): - assert getattr(obj, k) == getattr(dec, k) - else: - raise AttributeError(f"{k} not found in decoded object") - - -def test_slow_encode(): - - if sys.version_info < (3, 9): - pytest.skip("This test is only for python 3.9+") - - a = {"a": [1, 2, 3]} - slow_xml = XMLSerializer().encode(a, fast=False) - reference = """ - 1 - 2 - 3 -""" - assert slow_xml == reference - - -def test_include_header(): - - if sys.version_info < (3, 9): - pytest.skip("This test is only for python 3.9+") - - a = {"a": [1, 2, 3]} - header_xml = XMLSerializer().encode(a, use_header=True) - reference = '?xml version="1.0" encoding="UTF-8"?\n\n 1\n 2\n 3\n' - assert header_xml == reference diff --git a/tests/unit_tests/variable/__init__.py b/tests/unit_tests/variable/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/Objects/variable/test_descriptor_any_type.py b/tests/unit_tests/variable/test_descriptor_any_type.py similarity index 79% rename from 
tests/unit_tests/Objects/variable/test_descriptor_any_type.py rename to tests/unit_tests/variable/test_descriptor_any_type.py index 5b8a131b..70c4cc65 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_any_type.py +++ b/tests/unit_tests/variable/test_descriptor_any_type.py @@ -1,7 +1,7 @@ import pytest import numpy as np -from easyscience.Objects.variable.descriptor_any_type import DescriptorAnyType +from easyscience.variable import DescriptorAnyType from easyscience import global_object class TestDescriptorAnyType: @@ -75,18 +75,4 @@ def test_copy(self, descriptor: DescriptorAnyType): # Expect assert type(descriptor_copy) == DescriptorAnyType - assert descriptor_copy._value == descriptor._value - - def test_as_data_dict(self, clear, descriptor: DescriptorAnyType): - # When Then - descriptor_dict = descriptor.as_data_dict() - - # Expect - assert descriptor_dict == { - "name": "name", - "value": "string", - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorAnyType_0" - } \ No newline at end of file + assert descriptor_copy._value == descriptor._value \ No newline at end of file diff --git a/tests/unit_tests/Objects/variable/test_descriptor_array.py b/tests/unit_tests/variable/test_descriptor_array.py similarity index 97% rename from tests/unit_tests/Objects/variable/test_descriptor_array.py rename to tests/unit_tests/variable/test_descriptor_array.py index 2708f4ed..695f5fe2 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_array.py +++ b/tests/unit_tests/variable/test_descriptor_array.py @@ -6,8 +6,8 @@ import numpy as np -from easyscience.Objects.variable.descriptor_array import DescriptorArray -from easyscience.Objects.variable.descriptor_number import DescriptorNumber +from easyscience.variable import DescriptorArray +from easyscience import DescriptorNumber from easyscience import global_object class TestDescriptorArray: @@ -218,32 +218,6 @@ def test_copy(self, descriptor: 
DescriptorArray): assert type(descriptor_copy) == DescriptorArray assert np.array_equal(descriptor_copy._array.values, descriptor._array.values) assert descriptor_copy._array.unit == descriptor._array.unit - - def test_as_data_dict(self, clear, descriptor: DescriptorArray): - # When - descriptor_dict = descriptor.as_data_dict() - - # Expected dictionary - expected_dict = { - "name": "name", - "value": np.array([[1.0, 2.0], [3.0, 4.0]]), # Use numpy array for comparison - "unit": "m", - "variance": np.array([[0.1, 0.2], [0.3, 0.4]]), # Use numpy array for comparison - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorArray_0", - "dimensions": np.array(['dim0', 'dim1']), # Use numpy array for comparison - } - - # Then: Compare dictionaries key by key - for key, expected_value in expected_dict.items(): - if isinstance(expected_value, np.ndarray): - # Compare numpy arrays - assert np.array_equal(descriptor_dict[key], expected_value), f"Mismatch for key: {key}" - else: - # Compare other values directly - assert descriptor_dict[key] == expected_value, f"Mismatch for key: {key}" @pytest.mark.parametrize("unit_string, expected", [ ("1e+9", "dimensionless"), diff --git a/tests/unit_tests/Objects/variable/test_descriptor_base.py b/tests/unit_tests/variable/test_descriptor_base.py similarity index 93% rename from tests/unit_tests/Objects/variable/test_descriptor_base.py rename to tests/unit_tests/variable/test_descriptor_base.py index 38140a34..aeeb823e 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_base.py +++ b/tests/unit_tests/variable/test_descriptor_base.py @@ -1,7 +1,7 @@ import pytest from easyscience import global_object -from easyscience.Objects.variable.descriptor_base import DescriptorBase +from easyscience.variable.descriptor_base import DescriptorBase class TestDesciptorBase: @@ -140,19 +140,6 @@ def test_copy(self, descriptor: DescriptorBase): assert descriptor_copy._url == descriptor._url 
assert descriptor_copy._display_name == descriptor._display_name - def test_as_data_dict(self, clear, descriptor: DescriptorBase): - # When Then - descriptor_dict = descriptor.as_data_dict() - - # Expect - assert descriptor_dict == { - "name": "name", - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorBase_0", - } - def test_unique_name_generator(self, clear, descriptor: DescriptorBase): # When second_descriptor = DescriptorBase(name="test", unique_name="DescriptorBase_2") diff --git a/tests/unit_tests/Objects/variable/test_descriptor_bool.py b/tests/unit_tests/variable/test_descriptor_bool.py similarity index 81% rename from tests/unit_tests/Objects/variable/test_descriptor_bool.py rename to tests/unit_tests/variable/test_descriptor_bool.py index 63bf484a..4be20657 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_bool.py +++ b/tests/unit_tests/variable/test_descriptor_bool.py @@ -1,6 +1,6 @@ import pytest -from easyscience.Objects.variable.descriptor_bool import DescriptorBool +from easyscience.variable import DescriptorBool from easyscience import global_object class TestDescriptorBool: @@ -74,18 +74,4 @@ def test_copy(self, descriptor: DescriptorBool): # Expect assert type(descriptor_copy) == DescriptorBool - assert descriptor_copy._bool_value == descriptor._bool_value - - def test_as_data_dict(self, clear, descriptor: DescriptorBool): - # When Then - descriptor_dict = descriptor.as_data_dict() - - # Expect - assert descriptor_dict == { - "name": "name", - "value": True, - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorBool_0" - } \ No newline at end of file + assert descriptor_copy._bool_value == descriptor._bool_value \ No newline at end of file diff --git a/tests/unit_tests/Objects/variable/test_descriptor_number.py b/tests/unit_tests/variable/test_descriptor_number.py similarity index 96% rename from 
tests/unit_tests/Objects/variable/test_descriptor_number.py rename to tests/unit_tests/variable/test_descriptor_number.py index 62b359a4..5dc4e060 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_number.py +++ b/tests/unit_tests/variable/test_descriptor_number.py @@ -3,7 +3,7 @@ import scipp as sc from scipp import UnitError -from easyscience.Objects.variable.descriptor_number import DescriptorNumber +from easyscience import DescriptorNumber from easyscience import global_object class TestDescriptorNumber: @@ -30,6 +30,7 @@ def test_init(self, descriptor: DescriptorNumber): assert descriptor._scalar.value == 1 assert descriptor._scalar.unit == "m" assert descriptor._scalar.variance == 0.1 + assert descriptor._observers == [] # From super assert descriptor._name == "name" @@ -199,22 +200,6 @@ def test_copy(self, descriptor: DescriptorNumber): assert descriptor_copy._scalar.value == descriptor._scalar.value assert descriptor_copy._scalar.unit == descriptor._scalar.unit - def test_as_data_dict(self, clear, descriptor: DescriptorNumber): - # When Then - descriptor_dict = descriptor.as_data_dict() - - # Expect - assert descriptor_dict == { - "name": "name", - "value": 1.0, - "unit": "m", - "variance": 0.1, - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorNumber_0", - } - @pytest.mark.parametrize("unit_string, expected", [ ("1e+9", "dimensionless"), ("1000", "dimensionless"), diff --git a/tests/unit_tests/Objects/variable/test_descriptor_str.py b/tests/unit_tests/variable/test_descriptor_str.py similarity index 79% rename from tests/unit_tests/Objects/variable/test_descriptor_str.py rename to tests/unit_tests/variable/test_descriptor_str.py index 71c50715..aa593ed9 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_str.py +++ b/tests/unit_tests/variable/test_descriptor_str.py @@ -1,6 +1,6 @@ import pytest -from easyscience.Objects.variable.descriptor_str import DescriptorStr +from 
easyscience.variable import DescriptorStr from easyscience import global_object class TestDescriptorStr: @@ -73,18 +73,4 @@ def test_copy(self, descriptor: DescriptorStr): # Expect assert type(descriptor_copy) == DescriptorStr - assert descriptor_copy._string == descriptor._string - - def test_as_data_dict(self, clear, descriptor: DescriptorStr): - # When Then - descriptor_dict = descriptor.as_data_dict() - - # Expect - assert descriptor_dict == { - "name": "name", - "value": "string", - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorStr_0" - } \ No newline at end of file + assert descriptor_copy._string == descriptor._string \ No newline at end of file diff --git a/tests/unit_tests/Objects/variable/test_parameter.py b/tests/unit_tests/variable/test_parameter.py similarity index 64% rename from tests/unit_tests/Objects/variable/test_parameter.py rename to tests/unit_tests/variable/test_parameter.py index e00350b6..356ed83d 100644 --- a/tests/unit_tests/Objects/variable/test_parameter.py +++ b/tests/unit_tests/variable/test_parameter.py @@ -5,9 +5,10 @@ from scipp import UnitError -from easyscience.Objects.variable.parameter import Parameter -from easyscience.Objects.variable.descriptor_number import DescriptorNumber +from easyscience import Parameter +from easyscience import DescriptorNumber from easyscience import global_object +from easyscience import ObjBase class TestParameter: @pytest.fixture @@ -24,15 +25,35 @@ def parameter(self) -> Parameter: url="url", display_name="display_name", callback=self.mock_callback, - enabled="enabled", parent=None, ) return parameter + + @pytest.fixture + def normal_parameter(self) -> Parameter: + parameter = Parameter( + name="name", + value=1, + unit="m", + variance=0.01, + min=0, + max=10, + ) + return parameter @pytest.fixture def clear(self): global_object.map._clear() + def compare_parameters(self, parameter1: Parameter, parameter2: Parameter): + assert 
parameter1.value == parameter2.value + assert parameter1.unit == parameter2.unit + assert parameter1.variance == parameter2.variance + assert parameter1.min == parameter2.min + assert parameter1.max == parameter2.max + assert parameter1._min.unit == parameter2._min.unit + assert parameter1._max.unit == parameter2._max.unit + def test_init(self, parameter: Parameter): # When Then Expect assert parameter._min.value == 0 @@ -40,7 +61,7 @@ def test_init(self, parameter: Parameter): assert parameter._max.value == 10 assert parameter._max.unit == "m" assert parameter._callback == self.mock_callback - assert parameter._enabled == "enabled" + assert parameter._independent == True # From super assert parameter._scalar.value == 1 @@ -50,6 +71,8 @@ def test_init(self, parameter: Parameter): assert parameter._description == "description" assert parameter._url == "url" assert parameter._display_name == "display_name" + assert parameter._fixed == False + assert parameter._observers == [] def test_init_value_min_exception(self): # When @@ -69,7 +92,6 @@ def test_init_value_min_exception(self): url="url", display_name="display_name", callback=mock_callback, - enabled="enabled", parent=None, ) @@ -91,10 +113,364 @@ def test_init_value_max_exception(self): url="url", display_name="display_name", callback=mock_callback, - enabled="enabled", parent=None, ) + def test_make_dependent_on(self, normal_parameter: Parameter): + # When + independent_parameter = Parameter(name="independent", value=1, unit="m", variance=0.01, min=0, max=10) + + # Then + normal_parameter.make_dependent_on(dependency_expression='2*a', dependency_map={'a': independent_parameter}) + + # Expect + assert normal_parameter._independent == False + assert normal_parameter.dependency_expression == '2*a' + assert normal_parameter.dependency_map == {'a': independent_parameter} + self.compare_parameters(normal_parameter, 2*independent_parameter) + + # Then + independent_parameter.value = 2 + + # Expect + 
normal_parameter.value == 4 + self.compare_parameters(normal_parameter, 2*independent_parameter) + + def test_parameter_from_dependency(self, normal_parameter: Parameter): + # When Then + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + display_name='display_name', + ) + + # Expect + assert dependent_parameter._independent == False + assert dependent_parameter.dependency_expression == '2*a' + assert dependent_parameter.dependency_map == {'a': normal_parameter} + assert dependent_parameter.name == 'dependent' + assert dependent_parameter.display_name == 'display_name' + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + # Then + normal_parameter.value = 2 + + # Expect + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + def test_dependent_parameter_with_unique_name(self, clear, normal_parameter: Parameter): + # When Then + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*"Parameter_0"', + ) + + # Expect + assert dependent_parameter.dependency_expression == '2*"Parameter_0"' + assert dependent_parameter.dependency_map == {'__Parameter_0__': normal_parameter} + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + # Then + normal_parameter.value = 2 + + # Expect + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + def test_process_dependency_unique_names_double_quotes(self, clear, normal_parameter: Parameter): + # When + independent_parameter = Parameter(name="independent", value=1, unit="m", variance=0.01, min=0, max=10, unique_name='Special_name') + normal_parameter._dependency_map = {} + + # Then + normal_parameter._process_dependency_unique_names(dependency_expression='2*"Special_name"') + + # Expect + assert normal_parameter._dependency_map == {'__Special_name__': independent_parameter} + assert normal_parameter._clean_dependency_string == 
'2*__Special_name__' + + def test_process_dependency_unique_names_single_quotes(self, clear, normal_parameter: Parameter): + # When + independent_parameter = Parameter(name="independent", value=1, unit="m", variance=0.01, min=0, max=10, unique_name='Special_name') + independent_parameter_2 = Parameter(name="independent_2", value=1, unit="m", variance=0.01, min=0, max=10, unique_name='Special_name_2') + normal_parameter._dependency_map = {} + + # Then + normal_parameter._process_dependency_unique_names(dependency_expression="'Special_name' + 'Special_name_2'") + + # Expect + assert normal_parameter._dependency_map == {'__Special_name__': independent_parameter, + '__Special_name_2__': independent_parameter_2} + assert normal_parameter._clean_dependency_string == '__Special_name__ + __Special_name_2__' + + def test_process_dependency_unique_names_exception_unique_name_does_not_exist(self, clear, normal_parameter: Parameter): + # When + normal_parameter._dependency_map = {} + + # Then Expect + with pytest.raises(ValueError, match='A Parameter with unique_name Special_name does not exist. Please check your dependency expression.'): + normal_parameter._process_dependency_unique_names(dependency_expression='2*"Special_name"') + + def test_process_dependency_unique_names_exception_not_a_descriptorNumber(self, clear, normal_parameter: Parameter): + # When + normal_parameter._dependency_map = {} + base_obj = ObjBase(name='ObjBase', unique_name='base_obj') + + # Then Expect + with pytest.raises(ValueError, match='The object with unique_name base_obj is not a Parameter or DescriptorNumber. 
Please check your dependency expression.'): + normal_parameter._process_dependency_unique_names(dependency_expression='2*"base_obj"') + + @pytest.mark.parametrize("dependency_expression, dependency_map", [ + (2, {'a': Parameter(name='a', value=1)}), + ('2*a', ['a', Parameter(name='a', value=1)]), + ('2*a', {4: Parameter(name='a', value=1)}), + ('2*a', {'a': ObjBase(name='a')}), + ], ids=["dependency_expression_not_a_string", "dependency_map_not_a_dict", "dependency_map_keys_not_strings", "dependency_map_values_not_descriptor_number"]) + def test_parameter_from_dependency_input_exceptions(self, dependency_expression, dependency_map): + # When Then Expect + with pytest.raises(TypeError): + Parameter.from_dependency( + name = 'dependent', + dependency_expression=dependency_expression, + dependency_map=dependency_map, + ) + + @pytest.mark.parametrize("dependency_expression, error", [ + ('2*a + b', NameError), + ('2*a + 3*', SyntaxError), + ('2 + 2', TypeError), + ('2*"special_name"', ValueError), + ], ids=["parameter_not_in_map", "invalid_dependency_expression", "result_not_a_descriptor_number", "unique_name_does_not_exist"]) + def test_parameter_make_dependent_on_exceptions_cleanup_previously_dependent(self, normal_parameter, dependency_expression, error): + # When + independent_parameter = Parameter(name='independent', value=10, unit='s', variance=0.02) + dependent_parameter = Parameter.from_dependency( + name= 'dependent', + dependency_expression='best', + dependency_map={'best': independent_parameter} + ) + # Then Expect + # Check that the correct error is raised + with pytest.raises(error): + dependent_parameter.make_dependent_on( + dependency_expression=dependency_expression, + dependency_map={'a': normal_parameter}, + ) + # Check that everything is properly cleaned up + assert normal_parameter._observers == [] + assert dependent_parameter.independent == False + assert dependent_parameter.dependency_expression == 'best' + assert dependent_parameter.dependency_map 
== {'best': independent_parameter} + independent_parameter.value = 50 + self.compare_parameters(dependent_parameter, independent_parameter) + + def test_parameter_make_dependent_on_exceptions_cleanup_previously_independent(self, normal_parameter): + # When + independent_parameter = Parameter(name='independent', value=10, unit='s', variance=0.02) + # Then Expect + # Check that the correct error is raised + with pytest.raises(NameError): + independent_parameter.make_dependent_on( + dependency_expression='2*a + b', + dependency_map={'a': normal_parameter}, + ) + # Check that everything is properly cleaned up + assert normal_parameter._observers == [] + assert independent_parameter.independent == True + normal_parameter.value = 50 + assert independent_parameter.value == 10 + + def test_dependent_parameter_updates(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + + # Then Expect + normal_parameter.value = 2 + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + normal_parameter.variance = 0.02 + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + normal_parameter.error = 0.2 + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + normal_parameter.convert_unit("cm") + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + normal_parameter.min = 1 + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + normal_parameter.max = 300 + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + def test_dependent_parameter_indirect_updates(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + dependent_parameter_2 = Parameter.from_dependency( + name = 'dependent_2', + dependency_expression='10*a', + 
dependency_map={'a': normal_parameter}, + ) + dependent_parameter_3 = Parameter.from_dependency( + name = 'dependent_3', + dependency_expression='b+c', + dependency_map={'b': dependent_parameter, 'c': dependent_parameter_2}, + ) + # Then + normal_parameter.value = 2 + + # Expect + self.compare_parameters(dependent_parameter, 2*normal_parameter) + self.compare_parameters(dependent_parameter_2, 10*normal_parameter) + self.compare_parameters(dependent_parameter_3, 2*normal_parameter + 10*normal_parameter) + + def test_dependent_parameter_cyclic_dependencies(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + dependent_parameter_2 = Parameter.from_dependency( + name = 'dependent_2', + dependency_expression='2*b', + dependency_map={'b': dependent_parameter}, + ) + + # Then Expect + with pytest.raises(RuntimeError): + normal_parameter.make_dependent_on(dependency_expression='2*c', dependency_map={'c': dependent_parameter_2}) + # Check that everything is properly cleaned up + assert dependent_parameter_2._observers == [] + assert normal_parameter.independent == True + assert normal_parameter.value == 1 + normal_parameter.value = 50 + self.compare_parameters(dependent_parameter_2, 4*normal_parameter) + + def test_dependent_parameter_logical_dependency(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='a if a.value > 0 else -a', + dependency_map={'a': normal_parameter}, + ) + self.compare_parameters(dependent_parameter, normal_parameter) + + # Then + normal_parameter.value = -2 + + # Expect + self.compare_parameters(dependent_parameter, -normal_parameter) + + def test_dependent_parameter_return_is_descriptor_number(self): + # When + descriptor_number = DescriptorNumber(name='descriptor', value=1, unit='m', variance=0.01) + + # Then + 
dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*descriptor', + dependency_map={'descriptor': descriptor_number}, + ) + + # Expect + assert dependent_parameter.value == 2*descriptor_number.value + assert dependent_parameter.unit == descriptor_number.unit + assert dependent_parameter.variance == 0.04 + assert dependent_parameter.min == 2*descriptor_number.value + assert dependent_parameter.max == 2*descriptor_number.value + + def test_dependent_parameter_overwrite_dependency(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + # Then + normal_parameter_2 = Parameter(name='a2', value=-2, unit='m', variance=0.01, min=-10, max=0) + dependent_parameter.make_dependent_on(dependency_expression='3*a2', dependency_map={'a2': normal_parameter_2}) + normal_parameter.value = 3 + + # Expect + self.compare_parameters(dependent_parameter, 3*normal_parameter_2) + assert dependent_parameter.dependency_expression == '3*a2' + assert dependent_parameter.dependency_map == {'a2': normal_parameter_2} + assert normal_parameter._observers == [] + + def test_make_independent(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + assert dependent_parameter.independent == False + self.compare_parameters(dependent_parameter, 2*normal_parameter) + + # Then + dependent_parameter.make_independent() + normal_parameter.value = 5 + + # Expect + assert dependent_parameter.independent == True + assert normal_parameter._observers == [] + assert dependent_parameter.value == 2 + + def test_make_independent_exception(self, normal_parameter: Parameter): + # When Then Expect + with 
pytest.raises(AttributeError): + normal_parameter.make_independent() + + def test_independent_setter(self, normal_parameter: Parameter): + # When Then Expect + with pytest.raises(AttributeError): + normal_parameter.independent = False + + def test_independent_parameter_dependency_expression(self, normal_parameter: Parameter): + # When Then Expect + with pytest.raises(AttributeError): + normal_parameter.dependency_expression + + def test_dependent_parameter_dependency_expression_setter(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + + # Then Expect + with pytest.raises(AttributeError): + dependent_parameter.dependency_expression = '3*a' + + def test_independent_parameter_dependency_map(self, normal_parameter: Parameter): + # When Then Expect + with pytest.raises(AttributeError): + normal_parameter.dependency_map + + def test_dependent_parameter_dependency_map_setter(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + + # Then Expect + with pytest.raises(AttributeError): + dependent_parameter.dependency_map = {'a': normal_parameter} + def test_min(self, parameter: Parameter): # When Then Expect assert parameter.min == 0 @@ -108,6 +484,18 @@ def test_set_min(self, parameter: Parameter): # Expect assert parameter.min == 0.1 + def test_set_min_dependent_parameter(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + + # Then Expect + with pytest.raises(AttributeError): + dependent_parameter.min = 0.1 + def test_set_min_exception(self, parameter: Parameter): # When Then Expect with pytest.raises(ValueError): @@ -120,6 +508,18 @@ def 
test_set_max(self, parameter: Parameter): # Expect assert parameter.max == 10 + def test_set_max_dependent_parameter(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + + # Then Expect + with pytest.raises(AttributeError): + dependent_parameter.max = 10 + def test_set_max_exception(self, parameter: Parameter): # When Then Expect with pytest.raises(ValueError): @@ -135,10 +535,6 @@ def test_convert_unit(self, parameter: Parameter): assert parameter._max.value == 10000 assert parameter._max.unit == "mm" - def test_fixed(self, parameter: Parameter): - # When Then Expect - assert parameter.fixed == False - def test_set_fixed(self, parameter: Parameter): # When Then parameter.fixed = True @@ -146,6 +542,18 @@ def test_set_fixed(self, parameter: Parameter): # Expect assert parameter.fixed == True + def test_set_fixed_dependent_parameter(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + + # Then Expect + with pytest.raises(AttributeError): + dependent_parameter.fixed = True + @pytest.mark.parametrize("fixed", ["True", 1]) def test_set_fixed_exception(self, parameter: Parameter, fixed): # When Then Expect @@ -185,69 +593,6 @@ def test_repr_fixed(self, parameter: Parameter): # Then Expect assert repr(parameter) == "" - def test_bounds(self, parameter: Parameter): - # When Then Expect - assert parameter.bounds == (0, 10) - - def test_set_bounds(self, parameter: Parameter): - # When - self.mock_callback.fget.return_value = 1.0 # Ensure fget returns a scalar value - parameter._enabled = False - parameter._fixed = True - - # Then - parameter.bounds = (-10, 5) - - # Expect - assert parameter.min == -10 - assert parameter.max == 5 - assert parameter._enabled == True - assert parameter._fixed 
== False - - def test_set_bounds_exception_min(self, parameter: Parameter): - # When - parameter._enabled = False - parameter._fixed = True - - # Then - with pytest.raises(ValueError): - parameter.bounds = (2, 10) - - # Expect - assert parameter.min == 0 - assert parameter.max == 10 - assert parameter._enabled == False - assert parameter._fixed == True - - def test_set_bounds_exception_max(self, parameter: Parameter): - # When - parameter._enabled = False - parameter._fixed = True - - # Then - with pytest.raises(ValueError): - parameter.bounds = (0, 0.1) - - # Expect - assert parameter.min == 0 - assert parameter.max == 10 - assert parameter._enabled == False - assert parameter._fixed == True - - def test_enabled(self, parameter: Parameter): - # When - parameter._enabled = True - - # Then Expect - assert parameter.enabled is True - - def test_set_enabled(self, parameter: Parameter): - # When - parameter.enabled = False - - # Then Expect - assert parameter._enabled is False - def test_value_match_callback(self, parameter: Parameter): # When self.mock_callback.fget.return_value = 1.0 @@ -278,6 +623,18 @@ def test_set_value(self, parameter: Parameter): assert parameter._callback.fset.call_count == 1 assert parameter._scalar == sc.scalar(2, unit='m') + def test_set_value_dependent_parameter(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + + # Then Expect + with pytest.raises(AttributeError): + dependent_parameter.value = 3 + def test_full_value_match_callback(self, parameter: Parameter): # When self.mock_callback.fget.return_value = sc.scalar(1, unit='m') @@ -298,7 +655,31 @@ def test_set_full_value(self, parameter: Parameter): # When Then Expect with pytest.raises(AttributeError): parameter.full_value = sc.scalar(2, unit='s') - + + def test_set_variance_dependent_parameter(self, normal_parameter: Parameter): + # When + 
dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + + # Then Expect + with pytest.raises(AttributeError): + dependent_parameter.variance = 0.1 + + def test_set_error_dependent_parameter(self, normal_parameter: Parameter): + # When + dependent_parameter = Parameter.from_dependency( + name = 'dependent', + dependency_expression='2*a', + dependency_map={'a': normal_parameter}, + ) + + # Then Expect + with pytest.raises(AttributeError): + dependent_parameter.error = 0.1 + def test_copy(self, parameter: Parameter): # When Then self.mock_callback.fget.return_value = 1.0 # Ensure fget returns a scalar value @@ -317,28 +698,7 @@ def test_copy(self, parameter: Parameter): assert parameter_copy._description == parameter._description assert parameter_copy._url == parameter._url assert parameter_copy._display_name == parameter._display_name - assert parameter_copy._enabled == parameter._enabled - - def test_as_data_dict(self, clear, parameter: Parameter): - # When Then - self.mock_callback.fget.return_value = 1.0 # Ensure fget returns a scalar value - parameter_dict = parameter.as_data_dict() - - # Expect - assert parameter_dict == { - "name": "name", - "value": 1.0, - "unit": "m", - "variance": 0.01, - "min": 0, - "max": 10, - "fixed": False, - "description": "description", - "url": "url", - "display_name": "display_name", - "enabled": "enabled", - "unique_name": "Parameter_0", - } + assert parameter_copy._independent == parameter._independent @pytest.mark.parametrize("test, expected, expected_reverse", [ (Parameter("test", 2, "m", 0.01, -10, 20), Parameter("name + test", 3, "m", 0.02, -10, 30), Parameter("test + name", 3, "m", 0.02, -10, 30)),