Skip to content

163 resolution function to class #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/easyreflectometry/calculators/refl1d/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Tuple

import numpy as np
from easyreflectometry.experiment.resolution_functions import is_percentage_fhwm_resolution_function
from easyreflectometry.experiment.resolution_functions import PercentageFhwm
from refl1d import model
from refl1d import names

Expand Down Expand Up @@ -145,9 +145,9 @@ def calculate(self, q_array: np.ndarray, model_name: str) -> np.ndarray:
:return: reflectivity calculated at q
"""
sample = _build_sample(self.storage, model_name)
dq_array = self._resolution_function(q_array)
dq_array = self._resolution_function.smearing(q_array)

if is_percentage_fhwm_resolution_function(self._resolution_function):
if isinstance(self._resolution_function, PercentageFhwm):
# Get percentage of Q and change from sigma to FWHM
dq_array = dq_array * q_array / 100 / (2 * np.sqrt(2 * np.log(2)))

Expand Down
6 changes: 3 additions & 3 deletions src/easyreflectometry/calculators/refnx/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Tuple

import numpy as np
from easyreflectometry.experiment.resolution_functions import is_percentage_fhwm_resolution_function
from easyreflectometry.experiment.resolution_functions import PercentageFhwm
from refnx import reflect

from ..wrapper_base import WrapperBase
Expand Down Expand Up @@ -129,8 +129,8 @@ def calculate(self, q_array: np.ndarray, model_name: str) -> np.ndarray:
dq_type='pointwise',
)

dq_vector = self._resolution_function(q_array)
if is_percentage_fhwm_resolution_function(self._resolution_function):
dq_vector = self._resolution_function.smearing(q_array)
if isinstance(self._resolution_function, PercentageFhwm):
# FWHM Percentage resolution is constant given as
# For a constant resolution percentage refnx supports to pass a scalar value rather than a vector
dq_vector = dq_vector[0]
Expand Down
9 changes: 4 additions & 5 deletions src/easyreflectometry/calculators/wrapper_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from abc import abstractmethod
from typing import Callable

import numpy as np

from easyreflectometry.experiment import DEFAULT_RESOLUTION_FWHM_PERCENTAGE
from easyreflectometry.experiment import percentage_fhwm_resolution_function
from easyreflectometry.experiment import PercentageFhwm
from easyreflectometry.experiment import ResolutionFunction


class WrapperBase:
Expand All @@ -16,7 +15,7 @@ def __init__(self):
'item': {},
'model': {},
}
self._resolution_function = percentage_fhwm_resolution_function(DEFAULT_RESOLUTION_FWHM_PERCENTAGE)
self._resolution_function = PercentageFhwm()

def reset_storage(self):
"""Reset the storage area to blank."""
Expand Down Expand Up @@ -205,7 +204,7 @@ def get_item_value(self, name: str, key: str) -> float:
item = getattr(item, key)
return getattr(item, 'value')

def set_resolution_function(self, resolution_function: Callable[[np.array], np.array]) -> None:
def set_resolution_function(self, resolution_function: ResolutionFunction) -> None:
"""Set the resolution function for the calculator.

:param resolution_function: The resolution function
Expand Down
12 changes: 6 additions & 6 deletions src/easyreflectometry/experiment/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from .model import Model
from .model_collection import ModelCollection
from .resolution_functions import DEFAULT_RESOLUTION_FWHM_PERCENTAGE
from .resolution_functions import linear_spline_resolution_function
from .resolution_functions import percentage_fhwm_resolution_function
from .resolution_functions import LinearSpline
from .resolution_functions import PercentageFhwm
from .resolution_functions import ResolutionFunction

__all__ = (
DEFAULT_RESOLUTION_FWHM_PERCENTAGE,
percentage_fhwm_resolution_function,
linear_spline_resolution_function,
LinearSpline,
PercentageFhwm,
ResolutionFunction,
Model,
ModelCollection,
)
26 changes: 12 additions & 14 deletions src/easyreflectometry/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,21 @@
__author__ = 'github.com/arm61'

from numbers import Number
from typing import Callable
from typing import Union

import numpy as np
import yaml
from easyscience.Objects.ObjectClasses import BaseObj
from easyscience.Objects.ObjectClasses import Parameter

from easyreflectometry.experiment.resolution_functions import is_percentage_fhwm_resolution_function
from easyreflectometry.parameter_utils import get_as_parameter
from easyreflectometry.sample import BaseAssembly
from easyreflectometry.sample import Layer
from easyreflectometry.sample import LayerCollection
from easyreflectometry.sample import Sample

from .resolution_functions import percentage_fhwm_resolution_function
from .resolution_functions import PercentageFhwm
from .resolution_functions import ResolutionFunction

DEFAULTS = {
'scale': {
Expand Down Expand Up @@ -59,7 +58,7 @@ def __init__(
sample: Union[Sample, None] = None,
scale: Union[Parameter, Number, None] = None,
background: Union[Parameter, Number, None] = None,
resolution_function: Union[Callable[[np.array], float], None] = None,
resolution_function: Union[ResolutionFunction, None] = None,
name: str = 'EasyModel',
interface=None,
):
Expand All @@ -69,15 +68,15 @@ def __init__(
:param scale: Scaling factor of profile.
:param background: Linear background magnitude.
:param name: Name of the model, defaults to 'EasyModel'.
:param resolution_function: Resolution function, defaults to percentage_fhwm_resolution_function.
:param resolution_function: Resolution function, defaults to PercentageFhwm.
:param interface: Calculator interface, defaults to `None`.

"""

if sample is None:
sample = Sample(interface=interface)
if resolution_function is None:
resolution_function = percentage_fhwm_resolution_function(DEFAULTS['resolution']['value'])
resolution_function = PercentageFhwm(DEFAULTS['resolution']['value'])

scale = get_as_parameter('scale', scale, DEFAULTS)
background = get_as_parameter('background', background, DEFAULTS)
Expand All @@ -88,8 +87,6 @@ def __init__(
scale=scale,
background=background,
)
if not callable(resolution_function):
raise ValueError('Resolution function must be a callable.')
self.resolution_function = resolution_function
# Must be set after resolution function
self.interface = interface
Expand Down Expand Up @@ -140,12 +137,12 @@ def remove_item(self, idx: int) -> None:
del self.sample[idx]

@property
def resolution_function(self) -> Callable[[np.array], np.array]:
def resolution_function(self) -> ResolutionFunction:
"""Return the resolution function."""
return self._resolution_function

@resolution_function.setter
def resolution_function(self, resolution_function: Callable[[np.array], np.array]) -> None:
def resolution_function(self, resolution_function: ResolutionFunction) -> None:
"""Set the resolution function for the model."""
self._resolution_function = resolution_function
if self.interface is not None:
Expand Down Expand Up @@ -176,8 +173,8 @@ def uid(self) -> int:
@property
def _dict_repr(self) -> dict[str, dict[str, str]]:
"""A simplified dict representation."""
if is_percentage_fhwm_resolution_function(self._resolution_function):
resolution_value = self._resolution_function([0])[0]
if isinstance(self._resolution_function, PercentageFhwm):
resolution_value = self._resolution_function.as_dict()['constant']
resolution = f'{resolution_value} %'
else:
resolution = 'function of Q'
Expand All @@ -204,7 +201,8 @@ def as_dict(self, skip: list = None) -> dict:
if skip is None:
skip = []
this_dict = super().as_dict(skip=skip)
this_dict['sample'] = self.sample.as_dict()
this_dict['sample'] = self.sample.as_dict(skip=skip)
this_dict['resolution_function'] = self.resolution_function.as_dict()
return this_dict

@classmethod
Expand All @@ -220,5 +218,5 @@ def from_dict(cls, data: dict) -> Model:
# Ensure that the sample is also converted
# TODO Should probably be handled in easyscience
model.sample = model.sample.__class__.from_dict(data['sample'])

model.resolution_function = ResolutionFunction.from_dict(data['resolution_function'])
return model
2 changes: 1 addition & 1 deletion src/easyreflectometry/experiment/model_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
interface=None,
**kwargs,
):
if models is None:
if models == ():
models = [Model(interface=interface) for _ in range(SIZE_DEFAULT_COLLECTION)]
super().__init__(name, interface, *models, **kwargs)
self.interface = interface
Expand Down
59 changes: 32 additions & 27 deletions src/easyreflectometry/experiment/resolution_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,52 @@
FWHM = 2.35 * sigma [2 * np.sqrt(2 * np.log(2)) * sigma].
"""

from typing import Callable
from __future__ import annotations

from abc import abstractmethod
from typing import Union

import numpy as np

DEFAULT_RESOLUTION_FWHM_PERCENTAGE = 5.0


def percentage_fhwm_resolution_function(constant: float) -> Callable[[np.array], np.array]:
"""Create a resolution function that is constant across the q range.

:param constant: The constant resolution value.
"""

def _constant(q: Union[np.array, float]) -> np.array:
"""Function that calculates the resolution at a given q value.
class ResolutionFunction:
@abstractmethod
def smearing(q: Union[np.array, float]) -> np.array: ...

The function uses the data points from the encapsulating function and produces a linearly interpolated between them.
"""
return np.ones(np.array(q).size) * constant
@abstractmethod
def as_dict() -> dict: ...

return _constant
@classmethod
def from_dict(cls, data: dict) -> ResolutionFunction:
if data['smearing'] == 'PercentageFhwm':
return PercentageFhwm(data['constant'])
if data['smearing'] == 'LinearSpline':
return LinearSpline(data['q_data_points'], data['fwhm_values'])
raise ValueError('Unknown resolution function type')


def linear_spline_resolution_function(q_data_points: np.array, fwhm_values: np.array) -> Callable[[np.array], np.array]:
"""Create a resolution function that is linearly interpolated between given data points.
class PercentageFhwm:
def __init__(self, constant: Union[None, float] = None):
if constant is None:
constant = DEFAULT_RESOLUTION_FWHM_PERCENTAGE
self.constant = constant

:param q_data_points: The q values at which the resolution is defined.
:param fwhm_values: The resolution values at the given q values.
"""
def smearing(self, q: Union[np.array, float]) -> np.array:
return np.ones(np.array(q).size) * self.constant

def _linear(q: np.array) -> np.array:
"""Function that calculates the resolution at a given q value.
def as_dict(self) -> dict:
return {'smearing': 'PercentageFhwm', 'constant': self.constant}

The function uses the data points from the encapsulating function and produces a linearly interpolated between them.
"""
return np.interp(q, q_data_points, fwhm_values)

return _linear
class LinearSpline:
def __init__(self, q_data_points: np.array, fwhm_values: np.array):
self.q_data_points = q_data_points
self.fwhm_values = fwhm_values

def smearing(self, q: Union[np.array, float]) -> np.array:
return np.interp(q, self.q_data_points, self.fwhm_values)

def is_percentage_fhwm_resolution_function(resolution_function: Callable[[np.array], np.array]) -> bool:
"""Check if the resolution function is a constant."""
return 'constant' in resolution_function.__name__
def as_dict(self) -> dict:
return {'smearing': 'LinearSpline', 'q_data_points': self.q_data_points, 'fwhm_values': self.fwhm_values}
14 changes: 14 additions & 0 deletions src/easyreflectometry/sample/base_element_collection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any
from typing import List
from typing import Optional

import yaml
from easyscience.Objects.Groups import BaseCollection
Expand Down Expand Up @@ -48,6 +50,18 @@ def _dict_repr(self) -> dict:
"""
return {self.name: [i._dict_repr for i in self]}

def as_dict(self, skip: Optional[List[str]] = None) -> dict:
"""
Create a dictionary representation of the collection.

:return: A dictionary representation of the collection
"""
this_dict = super().as_dict(skip=skip)
this_dict['data'] = []
for collection_element in self:
this_dict['data'].append(collection_element.as_dict(skip=skip))
return this_dict

@classmethod
def from_dict(cls, data: dict) -> Any:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/easyreflectometry/sample/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def as_dict(self, skip: list = None) -> dict:
skip = []
this_dict = super().as_dict(skip=skip)
for i, layer in enumerate(self.data):
this_dict['data'][i] = layer.as_dict()
this_dict['data'][i] = layer.as_dict(skip=skip)
return this_dict

@classmethod
Expand Down
12 changes: 6 additions & 6 deletions tests/calculators/refnx/test_refnx_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import numpy as np
from easyreflectometry.calculators.refnx.wrapper import RefnxWrapper
from easyreflectometry.experiment import linear_spline_resolution_function
from easyreflectometry.experiment import percentage_fhwm_resolution_function
from easyreflectometry.experiment import LinearSpline
from easyreflectometry.experiment import PercentageFhwm
from numpy.testing import assert_allclose
from numpy.testing import assert_almost_equal
from numpy.testing import assert_equal
Expand Down Expand Up @@ -318,7 +318,7 @@ def test_calculate_github_test0(self):
p.add_item('Item2', 'MyModel')
p.add_item('Item3', 'MyModel')
p.add_item('Item4', 'MyModel')
p.set_resolution_function(percentage_fhwm_resolution_function(0))
p.set_resolution_function(PercentageFhwm(0))
p.update_model('MyModel', bkg=0)
q = np.array(
[
Expand Down Expand Up @@ -356,7 +356,7 @@ def test_calculate_github_test2(self):
p.add_layer_to_item('Layer2', 'Item2')
p.add_item('Item1', 'MyModel')
p.add_item('Item2', 'MyModel')
p.set_resolution_function(percentage_fhwm_resolution_function(0))
p.set_resolution_function(PercentageFhwm(0))
p.update_model('MyModel', bkg=0)
q = np.array(
[
Expand Down Expand Up @@ -410,7 +410,7 @@ def test_calculate_github_test4_constant_resolution(self):
p.add_item('Item2', 'MyModel')
p.add_item('Item3', 'MyModel')
p.add_item('Item4', 'MyModel')
p.set_resolution_function(percentage_fhwm_resolution_function(5))
p.set_resolution_function(PercentageFhwm(5))
p.update_model('MyModel', bkg=0)
assert_allclose(p.calculate(test4_dat[:, 0], 'MyModel'), test4_dat[:, 1], rtol=0.03)

Expand Down Expand Up @@ -450,7 +450,7 @@ def test_calculate_github_test4_spline_resolution(self):
p.add_item('Item4', 'MyModel')
p.update_model('MyModel', bkg=0)
sigma_to_fhwm = 2.355
p.set_resolution_function(linear_spline_resolution_function(test4_dat[:, 0], sigma_to_fhwm * test4_dat[:, 3]))
p.set_resolution_function(LinearSpline(test4_dat[:, 0], sigma_to_fhwm * test4_dat[:, 3]))
assert_allclose(p.calculate(test4_dat[:, 0], 'MyModel'), test4_dat[:, 1], rtol=0.03)


Expand Down
Loading
Loading