Skip to content

Commit acb576b

Browse files
committed
Improved the event descriptor to handle the case where slots are defined on parent object (eg an ABC)
1 parent fdc6cf0 commit acb576b

File tree

2 files changed

+90
-55
lines changed

2 files changed

+90
-55
lines changed

src/pyapp/events.py

Lines changed: 66 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,53 @@ async def on_new_message(message: str):
6464
_F = TypeVar("_F", bound=Callable[..., Any])
6565

6666

67+
class _ListenerDescriptor:
68+
"""Common base descriptor class."""
69+
70+
__slots__ = ("name",)
71+
72+
set_type: type
73+
name: Optional[str]
74+
75+
def __init__(self):
76+
"""Initialise descriptor."""
77+
self.name = None
78+
79+
def __set_name__(self, owner: type, name: str):
80+
"""Assign a name to the instance."""
81+
if self.name is None:
82+
self.name = name
83+
elif name != self.name:
84+
raise TypeError(
85+
"Cannot assign the same event to two different names "
86+
f"{self.name!r} and {name!r}"
87+
)
88+
89+
def get_listeners(self, instance):
90+
"""Get listeners from instance."""
91+
if self.name is None:
92+
raise TypeError(
93+
f"Cannot use {type(self).__name__!r} instance without "
94+
f"calling __set_name__ on it."
95+
)
96+
try:
97+
return instance.__dict__[self.name]
98+
except AttributeError:
99+
msg = (
100+
f"No '__dict__' attribute on {type(instance).__name__!r} "
101+
f"instance to store {self.name!r} events."
102+
)
103+
raise TypeError(msg) from None
104+
except KeyError:
105+
return None
106+
107+
def set_listeners(self, instance, listeners):
108+
"""Store listeners on instance."""
109+
# Called by subclasses, all access checks are performed in the get method
110+
instance.__dict__[self.name] = listeners
111+
return listeners
112+
113+
67114
class ListenerContext(Generic[_CT]):
68115
"""Context manager to manage temporary listeners.
69116
@@ -131,28 +178,19 @@ def __call__(self, *args, **kwargs):
131178
callback(*args, **kwargs)
132179

133180

134-
class Event(Generic[_CT]):
181+
class Event(Generic[_CT], _ListenerDescriptor):
135182
"""Event publisher descriptor.
136183
137184
Used to gain access to the listener list.
138185
139186
"""
140187

141-
__slots__ = ("name",)
188+
__slots__ = ()
142189

143190
def __get__(self, instance, owner) -> ListenerSet[_CT]:
144-
try:
145-
return instance.__dict__[self.name]
146-
except KeyError:
147-
instance.__dict__[self.name] = listeners = ListenerSet()
191+
if listeners := self.get_listeners(instance):
148192
return listeners
149-
150-
def __set_name__(self, owner, name):
151-
if hasattr(owner, "__slots__"):
152-
raise UnsupportedObject(
153-
"An Event cannot be used on an object with __slots__ defined."
154-
)
155-
self.name = name # pylint: disable=attribute-defined-outside-init
193+
return self.set_listeners(instance, ListenerSet())
156194

157195

158196
_ACT = TypeVar("_ACT", bound=Union[Callable[..., Coroutine], "AsyncListenerList"])
@@ -172,24 +210,19 @@ async def __call__(self, *args, **kwargs):
172210
await asyncio.wait(awaitables, return_when=asyncio.ALL_COMPLETED)
173211

174212

175-
class AsyncEvent(Generic[_ACT]):
213+
class AsyncEvent(Generic[_ACT], _ListenerDescriptor):
176214
"""Async event publisher descriptor.
177215
178216
Used to gain access to the listener list.
179217
180218
"""
181219

182-
__slots__ = ("name",)
220+
__slots__ = ()
183221

184222
def __get__(self, instance, owner) -> ListenerSet[_ACT]:
185-
try:
186-
return instance.__dict__[self.name]
187-
except KeyError:
188-
instance.__dict__[self.name] = listeners = AsyncListenerSet()
223+
if listeners := self.get_listeners(instance):
189224
return listeners
190-
191-
def __set_name__(self, owner, name):
192-
self.name = name # pylint: disable=attribute-defined-outside-init
225+
return self.set_listeners(instance, AsyncListenerSet())
193226

194227

195228
def listen_to(event: ListenerSet[_F]) -> Callable[[_F], _F]:
@@ -230,31 +263,22 @@ class CallbackBinding(CallbackBindingBase[_ACT]):
230263

231264
def __call__(self, *args, **kwargs):
232265
if self._callback:
233-
self._callback(*args, **kwargs)
266+
return self._callback(*args, **kwargs)
234267

235268

236-
class Callback(Generic[_CT]):
269+
class Callback(Generic[_CT], _ListenerDescriptor):
237270
"""Callback descriptor.
238271
239272
Used to attach a single callback.
240273
241274
"""
242275

243-
__slots__ = ("name",)
276+
__slots__ = ()
244277

245278
def __get__(self, instance, owner) -> CallbackBinding[_CT]:
246-
try:
247-
return instance.__dict__[self.name]
248-
except KeyError:
249-
instance.__dict__[self.name] = wrapper = CallbackBinding[_CT]()
250-
return wrapper
251-
252-
def __set_name__(self, owner, name):
253-
if hasattr(owner, "__slots__"):
254-
raise UnsupportedObject(
255-
"A Callback cannot be used on an object with __slots__ defined."
256-
)
257-
self.name = name # pylint: disable=attribute-defined-outside-init
279+
if listeners := self.get_listeners(instance):
280+
return listeners
281+
return self.set_listeners(instance, CallbackBinding())
258282

259283

260284
class AsyncCallbackBinding(CallbackBindingBase[_ACT]):
@@ -267,24 +291,19 @@ async def __call__(self, *args, **kwargs):
267291
return await self._callback(*args, **kwargs)
268292

269293

270-
class AsyncCallback(Generic[_ACT]):
294+
class AsyncCallback(Generic[_ACT], _ListenerDescriptor):
271295
"""Async callback descriptor.
272296
273297
Use to attach a single async callback
274298
275299
"""
276300

277-
__slots__ = ("name",)
301+
__slots__ = ()
278302

279303
def __get__(self, instance, owner) -> AsyncCallbackBinding[_ACT]:
280-
try:
281-
return instance.__dict__[self.name]
282-
except KeyError:
283-
instance.__dict__[self.name] = wrapper = AsyncCallbackBinding[_ACT]()
284-
return wrapper
285-
286-
def __set_name__(self, owner, name):
287-
self.name = name # pylint: disable=attribute-defined-outside-init
304+
if listeners := self.get_listeners(instance):
305+
return listeners
306+
return self.set_listeners(instance, AsyncCallbackBinding())
288307

289308

290309
def bind_to(callback: CallbackBinding[_F]) -> Callable[[_F], _F]:

tests/unit/test_events.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,13 @@ def on_target():
103103
assert len(instance.target) == 1
104104

105105
def test__when_object_has_slots(self):
106-
with pytest.raises(RuntimeError):
106+
class MyObject:
107+
__slots__ = ("foo",)
107108

108-
class MyObject:
109-
__slots__ = ("foo",)
109+
target = events.Event[Callable[[], None]]()
110110

111-
target = events.Event[Callable[[], None]]()
111+
with pytest.raises(TypeError):
112+
str(MyObject().target)
112113

113114

114115
class TestAsyncListenerSet:
@@ -205,13 +206,28 @@ def on_target():
205206
target = instance.target
206207
assert target._callback is on_target
207208

209+
def test_call(self):
210+
class MyObject:
211+
target = events.Callback[Callable[[], str]]()
212+
213+
instance = MyObject()
214+
215+
@events.bind_to(instance.target)
216+
def on_target():
217+
return "foo"
218+
219+
actual = instance.target()
220+
221+
assert actual == "foo"
222+
208223
def test__when_object_has_slots(self):
209-
with pytest.raises(RuntimeError):
224+
class MyObject:
225+
__slots__ = ("foo",)
210226

211-
class MyObject:
212-
__slots__ = ("foo",)
227+
target = events.Callback[Callable[[], None]]()
213228

214-
target = events.Callback[Callable[[], None]]()
229+
with pytest.raises(TypeError):
230+
str(MyObject().target)
215231

216232

217233
class TestAsyncCallbackBinding:

0 commit comments

Comments
 (0)