Skip to content

Commit e6ae923

Browse files
committed
more serialization
1 parent 2e65a01 commit e6ae923

File tree

4 files changed

+55
-7
lines changed

4 files changed

+55
-7
lines changed

src/easyreflectometry/experiment/model.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -206,17 +206,20 @@ def as_dict(self, skip: list = None) -> dict:
206206
return this_dict
207207

208208
@classmethod
209-
def from_dict(cls, data: dict) -> Model:
209+
def from_dict(cls, this_dict: dict) -> Model:
210210
"""
211211
Create a Model from a dictionary.
212212
213-
:param data: dictionary of the Model
213+
:param this_dict: dictionary of the Model
214214
:return: Model
215215
"""
216-
model = super().from_dict(data)
216+
resolution_function = ResolutionFunction.from_dict(this_dict['resolution_function'])
217+
del this_dict['resolution_function']
218+
sample = Sample.from_dict(this_dict['sample'])
219+
del this_dict['sample']
217220

218-
# Ensure that the sample is also converted
219-
# TODO Should probably be handled in easyscience
220-
model.sample = model.sample.__class__.from_dict(data['sample'])
221-
model.resolution_function = ResolutionFunction.from_dict(data['resolution_function'])
221+
model = super().from_dict(this_dict)
222+
223+
model.sample = sample
224+
model.resolution_function = resolution_function
222225
return model

src/easyreflectometry/experiment/model_collection.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
__author__ = 'github.com/arm61'
24

35
from typing import Optional
@@ -39,3 +41,20 @@ def remove_model(self, idx: int):
3941
:param idx: Index of the model to remove
4042
"""
4143
del self[idx]
44+
45+
@classmethod
46+
def from_dict(cls, this_dict: dict) -> ModelCollection:
47+
"""
48+
Create an instance of a collection from a dictionary.
49+
50+
:param data: The dictionary for the collection
51+
:return: An instance of the collection
52+
"""
53+
collection = super().from_dict(this_dict) # type: ModelCollection
54+
55+
if len(collection) != len(this_dict['data']):
56+
raise ValueError(f"Expected {len(collection)} models, got {len(this_dict['data'])}")
57+
for i, model_data in enumerate(this_dict['data']):
58+
collection[i] = Model.from_dict(model_data)
59+
60+
return collection

tests/experiment/test_model_collection.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,23 @@ def test_as_dict(self):
7070

7171
# Expect
7272
assert dict_repr['data'][0]['resolution_function'] == {'smearing': 'PercentageFhwm', 'constant': 5.0}
73+
74+
def test_dict_round_trip(self):
75+
# When
76+
model_1 = Model(name='Model1')
77+
model_2 = Model(name='Model2')
78+
model_3 = Model(name='Model3')
79+
80+
# Then
81+
collection = ModelCollection(model_1, model_2, model_3)
82+
83+
src_dict = collection.as_dict()
84+
85+
# Then
86+
collection_from_dict = ModelCollection.from_dict(src_dict)
87+
88+
# Expect
89+
assert collection.as_data_dict(skip=['resolution_function', 'interface']) == collection_from_dict.as_data_dict(
90+
skip=['resolution_function', 'interface']
91+
)
92+
assert collection[0]._resolution_function.smearing(5.5) == collection_from_dict[0]._resolution_function.smearing(5.5)

tests/sample/elements/materials/test_material_collection.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ def test_from_pars(self):
3030
assert p[0].name == 'Boron'
3131
assert p[1].name == 'Potassium'
3232

33+
def test_empty_list(self):
34+
p = MaterialCollection([])
35+
assert p.name == 'EasyMaterials'
36+
assert p.interface is None
37+
assert len(p) == 0
38+
3339
def test_dict_repr(self):
3440
p = MaterialCollection()
3541
assert p._dict_repr == {

0 commit comments

Comments
 (0)