Skip to content

Commit 5542655

Browse files
authored
ScaleSerializable mixin (#125)
* ScaleSerializable mixin * Drop 3.7 support
1 parent 193c5dd commit 5542655

File tree

4 files changed

+360
-2
lines changed

4 files changed

+360
-2
lines changed

.github/workflows/unittests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
runs-on: ubuntu-latest
1414
strategy:
1515
matrix:
16-
python-version: ['3.7', '3.8', '3.9', '3.10']
16+
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
1717

1818
steps:
1919
- uses: actions/checkout@v2

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ authors = [
1111
{ name = "Polkascan Foundation", email = "[email protected]" }
1212
]
1313

14-
requires-python = ">=3.6, <4"
14+
requires-python = ">=3.8, <4"
1515
classifiers = [
1616
"Development Status :: 5 - Production/Stable",
1717
"Intended Audience :: Developers",

scalecodec/mixins.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
import dataclasses
2+
from dataclasses import is_dataclass
3+
import enum
4+
from typing import Type, TypeVar, Union
5+
import typing
6+
import json
7+
8+
from scalecodec.base import ScaleTypeDef, ScaleType, ScaleBytes
9+
from scalecodec.types import Struct, Option, Vec, Enum
10+
11+
T = TypeVar('T')
12+
13+
14+
class ScaleSerializable:
    """Mixin adding SCALE (de)serialization to dataclasses and Enums.

    The SCALE type definition is derived from the dataclass fields (or the
    Enum variants) by `scale_type_def()`.  Field names with a trailing
    underscore (used to avoid clashes with Python keywords) map to the same
    name without the underscore on the SCALE side.
    """

    @classmethod
    def scale_type_def(cls) -> ScaleTypeDef:
        """Derive the ScaleTypeDef for this class.

        Dataclasses become a `Struct` with one member per field; Enums
        become a value-less SCALE `Enum` of the variant names.

        Raises:
            NotImplementedError: if cls is neither a dataclass nor an Enum.
        """
        if is_dataclass(cls):
            arguments = {}
            for field in dataclasses.fields(cls):
                arguments[field.name] = cls.dataclass_field_to_scale_typ_def(field)

            return Struct(**arguments)
        elif issubclass(cls, enum.Enum):
            # Simple enums map onto a SCALE Enum with value-less variants
            variants = {status.name: None for status in cls}
            return Enum(**variants)

        raise NotImplementedError

    def serialize(self) -> Union[str, int, float, bool, dict, list]:
        """Serialize this instance to a JSON-compatible Python value."""
        scale_type = self.to_scale_type()
        return scale_type.serialize()

    @classmethod
    def deserialize(cls: Type[T], data: Union[str, int, float, bool, dict, list]) -> T:
        """Build an instance of cls from a JSON-compatible Python value."""
        scale_type = cls.scale_type_def().new()
        scale_type.deserialize(data)
        return cls.from_scale_type(scale_type)

    def to_scale_type(self) -> ScaleType:
        """Convert this instance into a populated ScaleType object.

        Raises:
            NotImplementedError: if self is neither a dataclass nor an Enum.
        """
        if not is_dataclass(self) and not issubclass(self.__class__, enum.Enum):
            raise NotImplementedError("Type not supported.")

        scale_type = self.scale_type_def().new()

        if issubclass(self.__class__, enum.Enum):
            # Enum instances are represented by their variant name
            scale_type.deserialize(self.name)
        elif is_dataclass(self):
            value = {}
            for field in dataclasses.fields(self):

                actual_type = field.type
                # Trailing underscore in the Python name maps to the bare SCALE name
                field_name = field.name[:-1] if field.name.endswith('_') else field.name

                if typing.get_origin(actual_type) is typing.Union:
                    # Extract the arguments of the Union type
                    args = typing.get_args(actual_type)
                    if type(None) in args:
                        # If NoneType is in the args, it's an Optional
                        actual_type = [arg for arg in args if arg is not type(None)][0]

                if getattr(self, field.name) is None:
                    value[field_name] = None
                else:

                    if typing.get_origin(actual_type) is list:
                        actual_type = typing.get_args(actual_type)[0]

                        if issubclass(actual_type, ScaleSerializable):
                            value[field_name] = [i.serialize() for i in getattr(self, field.name)]
                        else:
                            value[field_name] = getattr(self, field.name)

                    # TODO too simplified now
                    elif issubclass(actual_type, ScaleSerializable):

                        value[field_name] = getattr(self, field.name).serialize()
                    else:
                        value[field_name] = getattr(self, field.name)

            scale_type.deserialize(value)

        return scale_type

    @classmethod
    def from_scale_type(cls: Type[T], scale_type: ScaleType) -> T:
        """Reconstruct an instance of cls from a populated ScaleType.

        Raises:
            ValueError: when a nested value cannot be converted, or the
                SCALE type definition is not one of the supported kinds.
            NotImplementedError: if cls is not a dataclass.
        """
        if is_dataclass(cls):

            fields = {}

            for field in dataclasses.fields(cls):

                # Trailing underscore in the Python name maps to the bare SCALE name
                scale_field_name = field.name[:-1] if field.name.endswith('_') else field.name

                actual_type = field.type

                if typing.get_origin(field.type) is typing.Union:
                    # Extract the arguments of the Union type
                    args = typing.get_args(field.type)
                    if type(None) in args:
                        # Optional field: look it up under its SCALE name
                        # (was field.name, which broke trailing-underscore fields
                        # because serialization writes the stripped name)
                        if scale_field_name in scale_type.value:
                            if scale_type.value[scale_field_name] is None:
                                fields[field.name] = None
                                continue
                            else:
                                actual_type = [arg for arg in args if arg is not type(None)][0]
                        else:
                            # Absent optional: fall back to the dataclass default
                            continue

                if typing.get_origin(actual_type) is list:
                    items = []
                    actual_type = typing.get_args(actual_type)[0]

                    if issubclass(type(scale_type.type_def), (Struct, Option)):
                        list_items = scale_type.value_object[scale_field_name].value_object
                    elif issubclass(type(scale_type.type_def), (Vec, Enum)):
                        # Enum value_object is a (variant, value) pair; index 1 holds the value
                        list_items = scale_type.value_object[1].value_object
                    else:
                        raise ValueError(f'Unsupported type: {type(scale_type.type_def)}')

                    for item in list_items:
                        if actual_type in [str, int, float, bool]:
                            items.append(item.value)
                        elif actual_type is bytes:
                            items.append(item.to_bytes())
                        elif is_dataclass(actual_type):
                            items.append(actual_type.from_scale_type(item))

                    fields[field.name] = items

                elif actual_type in [str, int, float, bool]:
                    fields[field.name] = scale_type.value[scale_field_name]
                elif actual_type is bytes:
                    fields[field.name] = scale_type.value_object[scale_field_name].to_bytes()
                elif is_dataclass(actual_type):
                    try:
                        # TODO unwrap Option
                        if issubclass(type(scale_type.type_def), (Struct, Option)):
                            field_scale_type = scale_type.value_object[scale_field_name]
                        elif issubclass(type(scale_type.type_def), Enum):
                            field_scale_type = scale_type.value_object[1]
                        else:
                            raise ValueError(f"Unexpected type {type(scale_type.type_def)}")

                        fields[field.name] = actual_type.from_scale_type(field_scale_type)
                    except (KeyError, TypeError) as e:
                        # Fail loudly instead of printing 'oeps' and silently
                        # dropping the field (previous debug leftover)
                        raise ValueError(
                            f"Cannot convert field '{field.name}' of {cls.__name__}: {e}"
                        ) from e
                elif issubclass(actual_type, enum.Enum):
                    fields[field.name] = actual_type[scale_type.value_object[1].value]
            return cls(**fields)
        raise NotImplementedError

    def to_scale_bytes(self) -> ScaleBytes:
        """SCALE-encode this instance to ScaleBytes."""
        scale_obj = self.to_scale_type()
        return scale_obj.encode()

    @classmethod
    def from_scale_bytes(cls: Type[T], scale_bytes: ScaleBytes) -> T:
        """Decode an instance of cls from SCALE-encoded bytes."""
        scale_obj = cls.scale_type_def().new()
        scale_obj.decode(scale_bytes)
        return cls.from_scale_type(scale_obj)

    def to_json(self) -> str:
        """Render the serialized value as an indented JSON document."""
        return json.dumps(self.serialize(), indent=4)

    @classmethod
    def from_json(cls: Type[T], json_data: str) -> T:
        """Inverse of to_json(): parse the JSON document, then deserialize.

        Fix: the JSON string is now actually parsed; previously the raw
        string was handed to deserialize() (the json.loads call was
        commented out), breaking the to_json/from_json round-trip.
        """
        return cls.deserialize(json.loads(json_data))

    @classmethod
    def dataclass_field_to_scale_typ_def(cls, field) -> ScaleTypeDef:
        """Map a single dataclass field to a ScaleTypeDef.

        An explicit definition in `field.metadata['scale']` always wins.
        Otherwise Optional[...] wraps the inner def in `Option` and
        List[...] wraps it in `Vec`.

        Raises:
            ValueError: for ambiguous (bytes/int) or unsupported field types.
        """
        if 'scale' in field.metadata:
            return field.metadata['scale']

        # Check if the field type is an instance of Optional
        actual_type = field.type
        wrap_option = False
        wrap_vec = False

        if typing.get_origin(field.type) is typing.Union:
            # Extract the arguments of the Union type
            args = typing.get_args(field.type)
            if type(None) in args:
                # If NoneType is in the args, it's an Optional
                wrap_option = True
                actual_type = [arg for arg in args if arg is not type(None)][0]

        if typing.get_origin(actual_type) is list:
            wrap_vec = True
            actual_type = typing.get_args(actual_type)[0]

        if is_dataclass(actual_type):
            if issubclass(actual_type, ScaleSerializable):
                scale_def = actual_type.scale_type_def()
            else:
                # Fix: was field.type.__class__, which always rendered <class 'type'>
                raise ValueError(f"Cannot serialize dataclass {field.type}")

        elif actual_type is bytes:
            raise ValueError("bytes is ambiguous; specify SCALE type def in metadata e.g. {'scale': H256}")
        elif actual_type is int:
            raise ValueError("int is ambiguous; specify SCALE type def in metadata e.g. {'scale': U32}")

        elif issubclass(actual_type, enum.Enum):
            variants = {status.name: None for status in actual_type}
            scale_def = Enum(**variants)

        else:
            raise ValueError(f"Cannot convert {actual_type} to ScaleTypeDef")

        if wrap_vec:
            scale_def = Vec(scale_def)
        if wrap_option:
            scale_def = Option(scale_def)

        return scale_def

test/test_mixins.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import enum
2+
import json
3+
import os
4+
import unittest
5+
from dataclasses import dataclass, field
6+
from os import path
7+
from typing import Optional, Type, Union, List
8+
9+
from scalecodec.base import ScaleBytes, ScaleType
10+
from scalecodec.mixins import ScaleSerializable, T
11+
from scalecodec.types import H256, U8, Array, Enum
12+
13+
14+
# Test definitions
15+
16+
17+
@dataclass
class ValidatorData(ScaleSerializable):
    """Validator key set; each field declares its SCALE codec via field metadata."""
    bandersnatch: bytes = field(metadata={'scale': H256})        # 32 bytes (H256)
    ed25519: bytes = field(metadata={'scale': H256})             # 32 bytes (H256)
    bls: bytes = field(metadata={'scale': Array(U8, 144)})       # 144 raw bytes
    metadata: bytes = field(metadata={'scale': Array(U8, 128)})  # 128 raw bytes
23+
24+
25+
@dataclass
class EpochMark(ScaleSerializable):
    """Epoch marker: an entropy value plus a fixed-size list of 6 H256 entries."""
    entropy: bytes = field(metadata={'scale': H256})                     # 32 bytes (H256)
    validators: List[bytes] = field(metadata={'scale': Array(H256, 6)})  # exactly 6 x 32 bytes
29+
30+
@dataclass
class OutputMarks(ScaleSerializable):
    """Wrapper holding an optional EpochMark (serialized as a SCALE Option)."""
    epoch_mark: Optional[EpochMark] = None
33+
34+
35+
class CustomErrorCode(ScaleSerializable, enum.Enum):
    """Error codes; serialized as a value-less SCALE Enum of the variant names."""
    bad_slot = 0  # Timeslot value must be strictly monotonic.
    unexpected_ticket = 1  # Received a ticket while in epoch's tail.
    bad_ticket_order = 2  # Tickets must be sorted.
    bad_ticket_proof = 3  # Invalid ticket ring proof.
    bad_ticket_attempt = 4  # Invalid ticket attempt value.
    reserved = 5  # Reserved
    duplicate_ticket = 6  # Found a ticket duplicate.
    too_many_tickets = 7  # Found amount of tickets > K
44+
45+
46+
@dataclass
class Output(ScaleSerializable):
    """Tagged union result: exactly one of `ok` or `err` is expected to be set.

    Overrides the mixin hooks because it is represented as a SCALE Enum
    ('ok' | 'err') rather than the default Struct mapping of its fields.
    """
    ok: Optional[OutputMarks] = None # Markers
    err: Optional[CustomErrorCode] = None

    @classmethod
    def scale_type_def(cls):
        # One enum variant per dataclass field, instead of the default Struct
        return Enum(
            ok=OutputMarks.scale_type_def(),
            err=CustomErrorCode.scale_type_def()
        )

    def to_scale_type(self) -> ScaleType:
        # Round-trip through serialize() so the correct enum variant is selected
        scale_type = self.scale_type_def().new()
        scale_type.deserialize(self.serialize())
        return scale_type

    @classmethod
    def deserialize(cls: Type[T], data: Union[str, int, float, bool, dict, list]) -> T:
        # Default mixin behaviour is sufficient; override kept for symmetry
        return super().deserialize(data)

    def serialize(self) -> Union[str, int, float, bool, dict, list]:
        # Emit the populated side as a single-key dict naming the enum variant
        if self.err is not None:
            return {'err': self.err.serialize()}
        else:
            return {'ok': self.ok.serialize()}
74+
75+
76+
class TestSerializableMixin(unittest.TestCase):
    """Tests for the ScaleSerializable mixin using the dataclasses defined above."""

    def setUp(self):
        # Reference ValidatorData built from its JSON-compatible (hex string) form
        data = {
            'bandersnatch': '0x5e465beb01dbafe160ce8216047f2155dd0569f058afd52dcea601025a8d161d',
            'ed25519': '0x3b6a27bcceb6a42d62a3a8d02a6f0d73653215771de243a63ac048a18b59da29',
            'bls': '0x000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000',
            'metadata': '0x0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'
        }

        self.test_obj = ValidatorData.deserialize(data)

    def test_dataclass_serialization(self):
        # 'ok' variant serializes to a single-key dict with nested markers
        output = Output(ok=OutputMarks(epoch_mark=None))
        value = output.serialize()
        self.assertEqual({'ok': {'epoch_mark': None}}, value)

        # 'err' variant serializes to the enum variant name
        output = Output(err=CustomErrorCode.duplicate_ticket)
        value = output.serialize()

        self.assertEqual({'err': 'duplicate_ticket'}, value)

    def test_dataclass_to_scale_type(self):
        # Round-trip: dataclass -> ScaleType -> dataclass must be lossless
        output = Output(
            ok=OutputMarks(
                epoch_mark=EpochMark(
                    entropy=bytes(32),
                    validators=[bytes(32), bytes(32), bytes(32), bytes(32), bytes(32), bytes(32)]
                )
            )
        )
        scale_type = output.to_scale_type()
        output2 = Output.from_scale_type(scale_type)
        self.assertEqual(output, output2)

    def test_deserialize(self):
        # Same payload as setUp; deserialize then re-serialize must be lossless
        data = {
            'bandersnatch': '0x5e465beb01dbafe160ce8216047f2155dd0569f058afd52dcea601025a8d161d',
            'ed25519': '0x3b6a27bcceb6a42d62a3a8d02a6f0d73653215771de243a63ac048a18b59da29',
            'bls': '0x000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000',
            'metadata': '0x0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'
        }

        validator_obj = ValidatorData.deserialize(data)

        self.assertEqual(self.test_obj, validator_obj)
        self.assertEqual(data, validator_obj.serialize())

    def test_from_to_scale_bytes(self):
        # Round-trip through the SCALE byte encoding must be lossless
        scale_data = self.test_obj.to_scale_bytes()

        validator_obj = ValidatorData.from_scale_bytes(scale_data)

        self.assertEqual(self.test_obj, validator_obj)
132+
133+
134+
# Allow running this test module directly (outside a test runner)
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)