Skip to content

Commit dae6c36

Browse files
committed
Third step in transitioning to pydantic: Atom Sites
1 parent 76c5d3c commit dae6c36

File tree

7 files changed

+222
-100
lines changed

7 files changed

+222
-100
lines changed

src/easydiffraction/analysis/analysis.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from easydiffraction.analysis.minimizers.minimizer_factory import MinimizerFactory
1616
from easydiffraction.core.parameters import Descriptor
1717
from easydiffraction.core.parameters import Parameter
18+
from easydiffraction.core.parameters2 import DescriptorStr
19+
from easydiffraction.core.parameters2 import Parameter as Parameter2
1820
from easydiffraction.core.singletons import ConstraintsHandler
1921
from easydiffraction.experiments.experiments import Experiments
2022
from easydiffraction.utils.formatting import paragraph
@@ -51,7 +53,7 @@ def _get_params_as_dataframe(
5153
rows = []
5254
for param in params:
5355
common_attrs = {}
54-
if isinstance(param, (Descriptor, Parameter)):
56+
if isinstance(param, (Descriptor, Parameter, Parameter2, DescriptorStr)):
5557
common_attrs = {
5658
'datablock': param.datablock_name,
5759
'category': param.category_key,
@@ -62,7 +64,7 @@ def _get_params_as_dataframe(
6264
'fittable': False,
6365
}
6466
param_attrs = {}
65-
if isinstance(param, Parameter):
67+
if isinstance(param, (Parameter, Parameter2)):
6668
param_attrs = {
6769
'fittable': True,
6870
'free': param.free,

src/easydiffraction/core/categories.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,13 @@ def __setattr__(self, key: str, value: Any) -> None:
118118
except AttributeError:
119119
attr = self._MISSING_ATTR
120120
# If replacing or assigning any descriptor/parameter instance
121-
if isinstance(value, (Descriptor, Parameter)):
121+
if isinstance(value, (Descriptor, Parameter, Parameter2, DescriptorStr)):
122122
value._parent = self
123123
object.__setattr__(self, key, value)
124124
# Dealing with existing descriptor/parameter instance
125-
elif attr is not self._MISSING_ATTR and isinstance(attr, (Descriptor, Parameter)):
125+
elif attr is not self._MISSING_ATTR and isinstance(
126+
attr, (Descriptor, Parameter, Parameter2, DescriptorStr)
127+
):
126128
# Special pre-handling for category entry name attribute
127129
if key == self._category_entry_attr_name:
128130
old_name = self.category_entry_name
@@ -316,7 +318,7 @@ def as_cif(self) -> str:
316318
if attr_name.startswith('_'):
317319
continue
318320
attr_obj = getattr(first_item, attr_name)
319-
if not isinstance(attr_obj, (Descriptor, Parameter)):
321+
if not isinstance(attr_obj, (Descriptor, Parameter, Parameter2, DescriptorStr)):
320322
continue
321323
tags = getattr(attr_obj, '_full_cif_names', []) or []
322324
if not tags:

src/easydiffraction/core/parameters2.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def __init__(self, default: Any):
3636
def validate(self, name, new, current):
3737
raise NotImplementedError()
3838

39-
@property
40-
def value_type(self):
41-
return type(self.default)
39+
# @property
40+
# def value_type(self):
41+
# return type(self.default)
4242

4343

4444
class RangeValidator(Validator):
@@ -69,9 +69,9 @@ def validate(self, name, new, current):
6969
log.debug(f'{name} set to validated: {new}.')
7070
return new
7171

72-
@property
73-
def value_type(self):
74-
return float
72+
# @property
73+
# def value_type(self):
74+
# return float
7575

7676

7777
class ListValidator(Validator):
@@ -80,28 +80,52 @@ def __init__(
8080
allowed_values,
8181
default=None,
8282
):
83-
self.allowed_values = allowed_values() if callable(allowed_values) else allowed_values
84-
if default is None:
85-
default = self.allowed_values[0] if self.allowed_values else None
83+
self._allowed_values = allowed_values
84+
self._default = default
8685
super().__init__(default)
8786

87+
@property
88+
def allowed_values(self):
89+
return self._allowed_values() if callable(self._allowed_values) else self._allowed_values
90+
91+
@allowed_values.setter
92+
def allowed_values(self, value):
93+
self._allowed_values = value
94+
95+
@property
96+
def default(self):
97+
return self._default() if callable(self._default) else self._default
98+
99+
@default.setter
100+
def default(self, value):
101+
self._default = value
102+
88103
def validate(self, name, new, current):
104+
allowed = self.allowed_values
89105
if current is None and new is None:
106+
# log.debug(f"{name} set to default: '{self.default}'.")
107+
# return self.default
90108
log.debug(f"{name} set to default: '{self.default}'.")
91-
return self.default
92-
if new not in self.allowed_values:
93-
new = f"'{new}'" if isinstance(new, str) else new
94-
message = f'{new} is not allowed for {name}.'
95-
log.warning(message, exc_type=UserWarning)
96-
return current if current is not None else self.default
109+
return self.default or (allowed[0] if allowed else None)
110+
if new not in allowed:
111+
# new = f"'{new}'" if isinstance(new, str) else new
112+
# message = f'{new} is not allowed for {name}.'
113+
# log.warning(message, exc_type=UserWarning)
114+
# return current if current is not None else self.default
115+
log.warning(f'{new!r} is not allowed for {name}.', exc_type=UserWarning)
116+
return (
117+
current
118+
if current is not None
119+
else (self.default or (allowed[0] if allowed else None))
120+
)
97121
log.debug(f"{name} set to validated: '{new}'.")
98122
return new
99123

100-
@property
101-
def value_type(self):
102-
if self.allowed_values:
103-
return type(self.allowed_values[0])
104-
return str
124+
# @property
125+
# def value_type(self):
126+
# if self.allowed_values:
127+
# return type(self.allowed_values[0])
128+
# return str
105129

106130

107131
class RegexValidator(Validator):

0 commit comments

Comments
 (0)