Skip to content

Commit 2e65a01

Browse files
committed
ability to serialize model_collection
1 parent a734e8d commit 2e65a01

File tree

5 files changed

+89
-3
lines changed

5 files changed

+89
-3
lines changed

src/easyreflectometry/experiment/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def as_dict(self, skip: list = None) -> dict:
201201
if skip is None:
202202
skip = []
203203
this_dict = super().as_dict(skip=skip)
204-
this_dict['sample'] = self.sample.as_dict()
204+
this_dict['sample'] = self.sample.as_dict(skip=skip)
205205
this_dict['resolution_function'] = self.resolution_function.as_dict()
206206
return this_dict
207207

src/easyreflectometry/experiment/model_collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(
1919
interface=None,
2020
**kwargs,
2121
):
22-
if models is None:
22+
if models == ():
2323
models = [Model(interface=interface) for _ in range(SIZE_DEFAULT_COLLECTION)]
2424
super().__init__(name, interface, *models, **kwargs)
2525
self.interface = interface

src/easyreflectometry/sample/base_element_collection.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import Any
2+
from typing import List
3+
from typing import Optional
24

35
import yaml
46
from easyscience.Objects.Groups import BaseCollection
@@ -48,6 +50,18 @@ def _dict_repr(self) -> dict:
4850
"""
4951
return {self.name: [i._dict_repr for i in self]}
5052

53+
def as_dict(self, skip: Optional[List[str]] = None) -> dict:
54+
"""
55+
Create a dictionary representation of the collection.
56+
57+
:return: A dictionary representation of the collection
58+
"""
59+
this_dict = super().as_dict(skip=skip)
60+
this_dict['data'] = []
61+
for collection_element in self:
62+
this_dict['data'].append(collection_element.as_dict(skip=skip))
63+
return this_dict
64+
5165
@classmethod
5266
def from_dict(cls, data: dict) -> Any:
5367
"""

src/easyreflectometry/sample/sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def as_dict(self, skip: list = None) -> dict:
6969
skip = []
7070
this_dict = super().as_dict(skip=skip)
7171
for i, layer in enumerate(self.data):
72-
this_dict['data'][i] = layer.as_dict()
72+
this_dict['data'][i] = layer.as_dict(skip=skip)
7373
return this_dict
7474

7575
@classmethod
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import unittest
2+
3+
from easyreflectometry.experiment.model import Model
4+
from easyreflectometry.experiment.model_collection import ModelCollection
5+
6+
7+
class TestModelCollection(unittest.TestCase):
8+
def test_default(self):
9+
# When Then
10+
collection = ModelCollection()
11+
12+
# Expect
13+
assert collection.name == 'EasyModels'
14+
assert collection.interface is None
15+
assert len(collection) == 2
16+
assert collection[0].name == 'EasyModel'
17+
assert collection[1].name == 'EasyModel'
18+
19+
def test_from_pars(self):
20+
# When
21+
model_1 = Model(name='Model1')
22+
model_2 = Model(name='Model2')
23+
model_3 = Model(name='Model3')
24+
25+
# Then
26+
collection = ModelCollection(model_1, model_2, model_3)
27+
28+
# Expect
29+
assert collection.name == 'EasyModels'
30+
assert collection.interface is None
31+
assert len(collection) == 3
32+
assert collection[0].name == 'Model1'
33+
assert collection[1].name == 'Model2'
34+
assert collection[2].name == 'Model3'
35+
36+
def test_add_model(self):
37+
# When
38+
model_1 = Model(name='Model1')
39+
model_2 = Model(name='Model2')
40+
41+
# Then
42+
collection = ModelCollection(model_1)
43+
collection.add_model(model_2)
44+
45+
# Expect
46+
assert len(collection) == 2
47+
assert collection[0].name == 'Model1'
48+
assert collection[1].name == 'Model2'
49+
50+
def test_delete_model(self):
51+
# When
52+
model_1 = Model(name='Model1')
53+
model_2 = Model(name='Model2')
54+
55+
# Then
56+
collection = ModelCollection(model_1, model_2)
57+
collection.remove_model(0)
58+
59+
# Expect
60+
assert len(collection) == 1
61+
assert collection[0].name == 'Model2'
62+
63+
def test_as_dict(self):
64+
# When
65+
model_1 = Model(name='Model1')
66+
collection = ModelCollection(model_1)
67+
68+
# Then
69+
dict_repr = collection.as_dict()
70+
71+
# Expect
72+
assert dict_repr['data'][0]['resolution_function'] == {'smearing': 'PercentageFhwm', 'constant': 5.0}

0 commit comments

Comments
 (0)