Skip to content

Commit ad7e5fe

Browse files
authored
Merge pull request #192 from CitrineInformatics/feature/robust-enum
Augment base enumeration class to behave well in string contexts
2 parents c08ec69 + 4fc5b21 commit ad7e5fe

File tree

8 files changed

+131
-41
lines changed

8 files changed

+131
-41
lines changed

gemd/entity/attribute/base_attribute.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(self,
4949
self._value = None
5050
self._template = None
5151
self._origin = None
52-
self._file_links = None,
52+
self._file_links = None
5353

5454
self.value = value
5555
self.template = template
@@ -106,15 +106,15 @@ def _template_type() -> Type:
106106
"""Get the expected type of template for this object (property of child)."""
107107

108108
@property
109-
def origin(self) -> str:
109+
def origin(self) -> Origin:
110110
"""Get origin."""
111111
return self._origin
112112

113113
@origin.setter
114114
def origin(self, origin: Union[Origin, str]):
115115
if origin is None:
116116
raise ValueError("origin must be specified (but may be `unknown`)")
117-
self._origin = Origin.get_value(origin)
117+
self._origin = Origin.from_str(origin, exception=True)
118118

119119
@property
120120
def file_links(self) -> List[FileLink]:

gemd/entity/object/material_run.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,13 @@ def measurements(self) -> List["MeasurementRun"]:
9696
return self._measurements
9797

9898
@property
99-
def sample_type(self) -> str:
99+
def sample_type(self) -> SampleType:
100100
"""Get the sample type."""
101101
return self._sample_type
102102

103103
@sample_type.setter
104104
def sample_type(self, sample_type: Union[SampleType, str]):
105-
self._sample_type = SampleType.get_value(sample_type)
105+
self._sample_type = SampleType.from_str(sample_type, exception=True)
106106

107107
@staticmethod
108108
def _spec_type() -> Type:

gemd/entity/setters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
T = TypeVar('T')
77

88

9-
def validate_list(obj: Union[Iterable[T], T],
9+
def validate_list(obj: Optional[Union[Iterable[T], T]],
1010
typ: Union[Iterable[Type], Type],
1111
*,
1212
trigger: Callable[[T], Optional[T]] = None) -> ValidList:

gemd/enumeration/base_enumeration.py

+78-20
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,79 @@
11
"""Base class for all enumerations."""
2+
from deprecation import deprecated
23
from enum import Enum
4+
from typing import Optional
35

46

5-
class BaseEnumeration(Enum):
6-
"""Enumeration class that can convert between enumerations and associated values."""
7+
class BaseEnumeration(str, Enum):
8+
"""
9+
Enumeration class that can convert between enumerations and associated values.
710
8-
def __init__(self, *_):
9-
"""Ensure that there are no duplicates in the enumeration."""
10-
cls = self.__class__
11-
if any(self.value == e.value for e in cls):
12-
raise ValueError("Duplicates not allowed in enumerated set of values {}".format(cls))
13-
if not isinstance(self.value, str):
11+
BaseEnumeration is a powerful support class for string enumerations. It inherits
12+
from both str and Enum to enable a class with str capabilities but still a
13+
restricted data space. All constructors are case-insensitive on input and a given
14+
enumeration can recognize multiple synonyms for input, though only one value will
15+
correspond to the value itsself. For example:
16+
17+
```
18+
Fruits(BaseEnumeration):
19+
APPLE = "Apple"
20+
AVOCADO = "Avocado", "Alligator Pear"
21+
```
22+
23+
will recognize "apple", "APPLE" and " aPpLe " as referring to Fruits.APPLE,
24+
and "avocado" and "alligator pear" as referring to Fruits.AVOCADO. However,
25+
since str(Fruits.AVOCADO) is "Avocado", Fruits.AVOCADO != "Alligator Pear".
26+
27+
"""
28+
29+
def __new__(cls, value: str, *args):
30+
"""Overloaded to allow for synonyms."""
31+
if any(not isinstance(x, str) for x in (value,) + args):
1432
raise ValueError("All values of enum {} must be strings".format(cls))
33+
if cls.from_str(value, exception=False) is not None:
34+
raise ValueError("Duplicates not allowed in enumerated set of values {}".format(cls))
35+
obj = str.__new__(cls, value)
36+
obj._value_ = value
37+
obj.synonyms = frozenset(args)
38+
obj.matches = frozenset([obj.lower()]).union(x.lower() for x in obj.synonyms)
39+
return obj
1540

1641
@classmethod
17-
def get_value(cls, name):
42+
def from_str(cls, val: str, *, exception: bool = False) -> Optional["BaseEnumeration"]:
43+
"""
44+
Given a string value, return the Enumeration object that matches.
45+
46+
Parameters
47+
----------
48+
val: str
49+
The string to match against. Leading and trailing whitespace is ignored.
50+
Case is ignored.
51+
exception: bool
52+
Whether to raise an error if the string doesn't match anything. Default: False.
53+
54+
Returns
55+
-------
56+
BaseEnumeration
57+
The matching enumerated element, or None
58+
59+
:param val:
60+
:param exception:
61+
:return:
62+
63+
"""
64+
if val is None:
65+
result = None
66+
else:
67+
result = next((x for x in cls if str.lower(val).strip() in x.matches), None)
68+
if exception and result is None:
69+
raise ValueError(f"{val} is not a valid {cls}; valid choices are {[x for x in cls]}")
70+
return result
71+
72+
@classmethod
73+
@deprecated(deprecated_in="1.15.0",
74+
removed_in="2.0.0",
75+
details="Enumerations autocast to values now.")
76+
def get_value(cls, name: str) -> str:
1877
"""
1978
Return a valid value associated with name.
2079
@@ -23,14 +82,13 @@ def get_value(cls, name):
2382
"""
2483
if name is None:
2584
return None
26-
if any(name == e.value for e in cls):
27-
return name
28-
if any(name == e for e in cls):
29-
return name.value
30-
raise ValueError("'{}' is not a valid choice for enumeration {}".format(name, cls))
85+
return cls.from_str(name, exception=True).value
3186

3287
@classmethod
33-
def get_enum(cls, name):
88+
@deprecated(deprecated_in="1.15.0",
89+
removed_in="2.0.0",
90+
details="Use from_str for retreiving the correct Enum object.")
91+
def get_enum(cls, name: str) -> "BaseEnumeration":
3492
"""
3593
Return the enumeration associated with name.
3694
@@ -39,8 +97,8 @@ def get_enum(cls, name):
3997
"""
4098
if name is None:
4199
return None
42-
if any(name == e.value for e in cls):
43-
return next(e for e in cls if e.value == name)
44-
if any(name == e for e in cls):
45-
return name
46-
raise ValueError("'{}' is not a valid choice for enumeration {}".format(name, cls))
100+
return cls.from_str(name, exception=True)
101+
102+
def __str__(self):
103+
"""Return the value of the enumeration object."""
104+
return self.value

gemd/json/gemd_encoder.py

-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from uuid import UUID
33

44
from gemd.entity.dict_serializable import DictSerializable
5-
from gemd.enumeration.base_enumeration import BaseEnumeration
65

76

87
class GEMDEncoder(JSONEncoder):
@@ -12,8 +11,6 @@ def default(self, o):
1211
"""Default encoder implementation."""
1312
if isinstance(o, DictSerializable):
1413
return o.as_dict()
15-
elif isinstance(o, BaseEnumeration):
16-
return o.value
1714
elif isinstance(o, UUID):
1815
return str(o)
1916
else:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
packages.append("")
55

66
setup(name='gemd',
7-
version='1.14.1',
7+
version='1.15.0',
88
python_requires='>=3.7',
99
url='http://github.com/CitrineInformatics/gemd-python',
1010
description="Python binding for Citrine's GEMD data model",

tests/enumeration/test_enumeration.py

+44-10
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,24 @@ class GoodClass(BaseEnumeration):
1313
RED = "Red"
1414
BLUE = "Blue"
1515

16-
assert GoodClass.get_value("Red") == "Red"
17-
assert GoodClass.get_value(GoodClass.BLUE) == "Blue"
18-
assert GoodClass.get_value(None) is None
19-
assert GoodClass.get_enum("Red") == GoodClass.RED
20-
assert GoodClass.get_enum(GoodClass.BLUE) == GoodClass.BLUE
21-
assert GoodClass.get_enum(None) is None
22-
with pytest.raises(ValueError):
23-
GoodClass.get_value("Green")
24-
with pytest.raises(ValueError):
25-
GoodClass.get_enum("Green")
16+
with pytest.deprecated_call():
17+
assert GoodClass.get_value("Red") == "Red"
18+
with pytest.deprecated_call():
19+
assert GoodClass.get_value(GoodClass.BLUE) == "Blue"
20+
with pytest.deprecated_call():
21+
assert GoodClass.get_value(None) is None
22+
with pytest.deprecated_call():
23+
assert GoodClass.get_enum("Red") == GoodClass.RED
24+
with pytest.deprecated_call():
25+
assert GoodClass.get_enum(GoodClass.BLUE) == GoodClass.BLUE
26+
with pytest.deprecated_call():
27+
assert GoodClass.get_enum(None) is None
28+
with pytest.deprecated_call():
29+
with pytest.raises(ValueError):
30+
GoodClass.get_value("Green")
31+
with pytest.deprecated_call():
32+
with pytest.raises(ValueError):
33+
GoodClass.get_enum("Green")
2634

2735

2836
def test_json_serde():
@@ -47,3 +55,29 @@ class BadClass1(BaseEnumeration):
4755
class BadClass2(BaseEnumeration):
4856
FIRST = "one"
4957
SECOND = 2
58+
59+
60+
def test_string_enum():
61+
"""Test that the synonym mechanism works."""
62+
63+
class TestEnum(BaseEnumeration):
64+
ONE = "One", "1"
65+
TWO = "Two", "2"
66+
67+
assert TestEnum.ONE == "One", "Equality failed"
68+
assert str(TestEnum.ONE) == "One", "Equality failed, cast"
69+
assert TestEnum.ONE != "1", "Equality worked for synonym"
70+
assert TestEnum.from_str("One") == TestEnum.ONE, "from_str worked"
71+
assert TestEnum.from_str("ONE") == TestEnum.ONE, "from_str, caps worked"
72+
assert TestEnum.from_str("one") == TestEnum.ONE, "from_str, lower worked"
73+
assert TestEnum.from_str("1") == TestEnum.ONE, "from_str, synonym worked"
74+
assert TestEnum.from_str(None) is None, "from_str, bad returned None"
75+
assert TestEnum.from_str("1.0") is None, "from_str, bad returned None"
76+
with pytest.raises(ValueError, match="valid"):
77+
TestEnum.from_str("1.0", exception=True)
78+
for key in TestEnum.TWO.synonyms:
79+
assert key != TestEnum.TWO, f"Synonym {key} was equal?"
80+
assert TestEnum.from_str(key) == TestEnum.TWO, f"from_str didn't resolve {key}"
81+
assert (
82+
TestEnum.from_str(key.upper()) == TestEnum.TWO
83+
), f"from_str didn't resolve {key.upper()}"

tests/json/test_json.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def test_enumeration_serde():
112112
"""An enumeration should get serialized as a string."""
113113
condition = Condition(name="A condition", notes=Origin.UNKNOWN)
114114
copy_condition = GEMDJson().copy(condition)
115-
assert copy_condition.notes == Origin.get_value(condition.notes)
115+
assert copy_condition.notes == Origin.UNKNOWN.value
116+
assert not isinstance(copy_condition.notes, Origin)
116117

117118

118119
def test_attribute_serde():

0 commit comments

Comments
 (0)