|
13 | 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
| 16 | +import enum |
16 | 17 |
|
17 | 18 | import math
|
18 | 19 | import struct
|
19 | 20 | from typing import Union, Optional
|
20 | 21 |
|
21 | 22 | from scalecodec.base import ScaleType, ScaleBytes, ScalePrimitive, ScaleTypeDef
|
22 | 23 | from scalecodec.constants import TYPE_DECOMP_MAX_RECURSIVE
|
23 |
| -from scalecodec.exceptions import ScaleEncodeException, ScaleDecodeException, ScaleDeserializeException |
| 24 | +from scalecodec.exceptions import ScaleEncodeException, ScaleDecodeException, ScaleDeserializeException, \ |
| 25 | + ScaleSerializeException |
24 | 26 |
|
25 | 27 |
|
26 | 28 | class UnsignedInteger(ScalePrimitive):
|
@@ -230,6 +232,8 @@ def decode(self, data) -> dict:
|
230 | 232 | return value
|
231 | 233 |
|
232 | 234 | def serialize(self, value: dict) -> dict:
|
| 235 | + if value is None: |
| 236 | + raise ScaleSerializeException('Value cannot be None') |
233 | 237 | return {k: obj.value for k, obj in value.items()}
|
234 | 238 |
|
235 | 239 | def deserialize(self, value: dict) -> dict:
|
@@ -400,6 +404,12 @@ def deserialize(self, value: Union[str, dict]) -> tuple:
|
400 | 404 | if type(value) is str:
|
401 | 405 | value = {value: None}
|
402 | 406 |
|
| 407 | + if isinstance(value, enum.Enum): |
| 408 | + value = {value.name: None} |
| 409 | + |
| 410 | + if len(list(value.items())) != 1: |
| 411 | + raise ScaleDeserializeException("Only one variant can be specified for enums") |
| 412 | + |
403 | 413 | enum_key, enum_value = list(value.items())[0]
|
404 | 414 |
|
405 | 415 | for idx, (variant_name, variant_obj) in enumerate(self.variants.items()):
|
|
0 commit comments