Skip to content

Commit 0089029

Browse files
committed
Add serialization registry
1 parent 13cef45 commit 0089029

File tree

5 files changed

+254
-61
lines changed

5 files changed

+254
-61
lines changed

README.rst

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ to 10, and all ``websocket.send!`` channels to 20:
171171
If you want to enforce a matching order, use an ``OrderedDict`` as the
172172
argument; channels will then be matched in the order the dict provides them.
173173

174+
.. _encryption
174175
``symmetric_encryption_keys``
175176
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
176177

@@ -237,6 +238,44 @@ And then in your channels consumer, you can implement the handler:
237238
async def redis_disconnect(self, *args):
238239
# Handle disconnect
239240
241+
242+
243+
``serializer_format``
244+
~~~~~~~~~~~~~~~~~~~~~~
245+
By default every message which reach redis is encoded using `msgpack <https://msgpack.org/>`_.
246+
It is also possible to switch to `JSON <http://www.json.org/>`_:
247+
248+
.. code-block:: python
249+
250+
CHANNEL_LAYERS = {
251+
"default": {
252+
"BACKEND": "channels_redis.core.RedisChannelLayer",
253+
"CONFIG": {
254+
"hosts": ["redis://:[email protected]:6379/0"],
255+
"serializer_format": "json",
256+
},
257+
},
258+
}
259+
260+
A new serializer may be registered (or can be overriden) by using ``channels_redis.serializers.registry``,
261+
providing a class which extends ``channels_redis.serializers.BaseMessageSerializer``, implementing ``dumps``
262+
and ``loads`` methods, or which provides ``serialize``/``deserialize`` methods and calling the registration method on registry:
263+
264+
.. code-block:: python
265+
266+
from channels_redis.serializers import registry
267+
268+
class MyFormatSerializer:
269+
def serialize(self, message):
270+
...
271+
def deserialize(self, message):
272+
...
273+
274+
registry.register_serializer('myformat', MyFormatSerializer)
275+
276+
**NOTE**: Serializers also perform the encryption job see *symmetric_encryption_keys*.
277+
278+
240279
Dependencies
241280
------------
242281

channels_redis/core.py

Lines changed: 12 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,15 @@
55
import hashlib
66
import itertools
77
import logging
8-
import random
98
import time
109
import uuid
1110

12-
import msgpack
1311
from redis import asyncio as aioredis
1412

1513
from channels.exceptions import ChannelFull
1614
from channels.layers import BaseChannelLayer
1715

16+
from .serializers import registry
1817
from .utils import (
1918
_close_redis,
2019
_consistent_hash,
@@ -115,6 +114,7 @@ def __init__(
115114
capacity=100,
116115
channel_capacity=None,
117116
symmetric_encryption_keys=None,
117+
serializer_format="msgpack",
118118
):
119119
# Store basic information
120120
self.expiry = expiry
@@ -126,15 +126,23 @@ def __init__(
126126
# Configure the host objects
127127
self.hosts = decode_hosts(hosts)
128128
self.ring_size = len(self.hosts)
129+
# serialization
130+
self._serializer = registry.get_serializer(
131+
serializer_format,
132+
# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
133+
random_prefix_length=12,
134+
expiry=self.expiry,
135+
symmetric_encryption_keys=symmetric_encryption_keys,
136+
)
137+
self.serialize = self._serializer.serialize
138+
self.deserialize = self._serializer.deserialize
129139
# Cached redis connection pools and the event loop they are from
130140
self._layers = {}
131141
# Normal channels choose a host index by cycling through the available hosts
132142
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
133143
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
134144
# Decide on a unique client prefix to use in ! sections
135145
self.client_prefix = uuid.uuid4().hex
136-
# Set up any encryption objects
137-
self._setup_encryption(symmetric_encryption_keys)
138146
# Number of coroutines trying to receive right now
139147
self.receive_count = 0
140148
# The receive lock
@@ -154,24 +162,6 @@ def __init__(
154162
def create_pool(self, index):
155163
return create_pool(self.hosts[index])
156164

157-
def _setup_encryption(self, symmetric_encryption_keys):
158-
# See if we can do encryption if they asked
159-
if symmetric_encryption_keys:
160-
if isinstance(symmetric_encryption_keys, (str, bytes)):
161-
raise ValueError(
162-
"symmetric_encryption_keys must be a list of possible keys"
163-
)
164-
try:
165-
from cryptography.fernet import MultiFernet
166-
except ImportError:
167-
raise ValueError(
168-
"Cannot run with encryption without 'cryptography' installed."
169-
)
170-
sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
171-
self.crypter = MultiFernet(sub_fernets)
172-
else:
173-
self.crypter = None
174-
175165
### Channel layer API ###
176166

177167
extensions = ["groups", "flush"]
@@ -650,31 +640,6 @@ def _group_key(self, group):
650640
"""
651641
return f"{self.prefix}:group:{group}".encode("utf8")
652642

653-
### Serialization ###
654-
655-
def serialize(self, message):
656-
"""
657-
Serializes message to a byte string.
658-
"""
659-
value = msgpack.packb(message, use_bin_type=True)
660-
if self.crypter:
661-
value = self.crypter.encrypt(value)
662-
663-
# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
664-
random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big")
665-
return random_prefix + value
666-
667-
def deserialize(self, message):
668-
"""
669-
Deserializes from a byte string.
670-
"""
671-
# Removes the random prefix
672-
message = message[12:]
673-
674-
if self.crypter:
675-
message = self.crypter.decrypt(message, self.expiry + 10)
676-
return msgpack.unpackb(message, raw=False)
677-
678643
### Internal functions ###
679644

680645
def consistent_hash(self, value):

channels_redis/pubsub.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import logging
44
import uuid
55

6-
import msgpack
76
from redis import asyncio as aioredis
87

8+
from .serializers import registry
99
from .utils import (
1010
_close_redis,
1111
_consistent_hash,
@@ -25,10 +25,23 @@ async def _async_proxy(obj, name, *args, **kwargs):
2525

2626

2727
class RedisPubSubChannelLayer:
28-
def __init__(self, *args, **kwargs) -> None:
28+
def __init__(
29+
self,
30+
*args,
31+
symmetric_encryption_keys=None,
32+
serializer_format="msgpack",
33+
**kwargs,
34+
) -> None:
2935
self._args = args
3036
self._kwargs = kwargs
3137
self._layers = {}
38+
# serialization
39+
self._serializer = registry.get_serializer(
40+
serializer_format,
41+
symmetric_encryption_keys=symmetric_encryption_keys,
42+
)
43+
self.serialize = self._serializer.serialize
44+
self.deserialize = self._serializer.deserialize
3245

3346
def __getattr__(self, name):
3447
if name in (
@@ -44,18 +57,6 @@ def __getattr__(self, name):
4457
else:
4558
return getattr(self._get_layer(), name)
4659

47-
def serialize(self, message):
48-
"""
49-
Serializes message to a byte string.
50-
"""
51-
return msgpack.packb(message)
52-
53-
def deserialize(self, message):
54-
"""
55-
Deserializes from a byte string.
56-
"""
57-
return msgpack.unpackb(message)
58-
5960
def _get_layer(self):
6061
loop = asyncio.get_running_loop()
6162

channels_redis/serializers.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import abc
2+
import json
3+
import random
4+
5+
6+
class SerializerDoesNotExist(KeyError):
7+
"""The requested serializer was not found."""
8+
9+
10+
class BaseMessageSerializer(abc.ABC):
11+
def __init__(
12+
self,
13+
symmetric_encryption_keys=None,
14+
random_prefix_length=0,
15+
expiry=None,
16+
):
17+
self.random_prefix_length = random_prefix_length
18+
self.expiry = expiry
19+
# Set up any encryption objects
20+
self._setup_encryption(symmetric_encryption_keys)
21+
22+
def _setup_encryption(self, symmetric_encryption_keys):
23+
# See if we can do encryption if they asked
24+
if symmetric_encryption_keys:
25+
if isinstance(symmetric_encryption_keys, (str, bytes)):
26+
raise ValueError(
27+
"symmetric_encryption_keys must be a list of possible keys"
28+
)
29+
try:
30+
from cryptography.fernet import MultiFernet
31+
except ImportError:
32+
raise ValueError(
33+
"Cannot run with encryption without 'cryptography' installed."
34+
)
35+
sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
36+
self.crypter = MultiFernet(sub_fernets)
37+
else:
38+
self.crypter = None
39+
40+
@abc.abstractmethod
41+
def dumps(self, message):
42+
raise NotImplementedError
43+
44+
@abc.abstractmethod
45+
def loads(self, message):
46+
raise NotImplementedError
47+
48+
def serialize(self, message):
49+
"""
50+
Serializes message to a byte string.
51+
"""
52+
message = self.dumps(message)
53+
# ensure message is bytes
54+
if isinstance(message, str):
55+
message = message.encode("utf-8")
56+
if self.crypter:
57+
message = self.crypter.encrypt(message)
58+
59+
if self.random_prefix_length > 0:
60+
# provide random prefix
61+
message = (
62+
random.getrandbits(8 * self.random_prefix_length).to_bytes(
63+
self.random_prefix_length, "big"
64+
)
65+
+ message
66+
)
67+
return message
68+
69+
def deserialize(self, message):
70+
"""
71+
Deserializes from a byte string.
72+
"""
73+
if self.random_prefix_length > 0:
74+
# Removes the random prefix
75+
message = message[self.random_prefix_length :] # noqa: E203
76+
77+
if self.crypter:
78+
ttl = self.expiry if self.expiry is None else self.expiry + 10
79+
message = self.crypter.decrypt(message, ttl)
80+
return self.loads(message)
81+
82+
83+
class MissingSerializer(BaseMessageSerializer):
84+
exception = None
85+
86+
def __init__(self, *args, **kwargs):
87+
raise self.exception
88+
89+
90+
class JSONSerializer(BaseMessageSerializer):
91+
dumps = staticmethod(json.dumps)
92+
loads = staticmethod(json.loads)
93+
94+
95+
# code ready for a future in which msgpack may become an optional dependency
96+
try:
97+
import msgpack
98+
except ImportError as exc:
99+
100+
class MsgPackSerializer(MissingSerializer):
101+
exception = exc
102+
103+
else:
104+
105+
class MsgPackSerializer(BaseMessageSerializer):
106+
dumps = staticmethod(msgpack.packb)
107+
loads = staticmethod(msgpack.unpackb)
108+
109+
110+
class SerializersRegistry:
111+
def __init__(self):
112+
self._registry = {}
113+
114+
def register_serializer(self, format, serializer_class):
115+
"""
116+
Register a new serializer for given format
117+
"""
118+
assert isinstance(serializer_class, type) and (
119+
issubclass(serializer_class, BaseMessageSerializer)
120+
or hasattr(serializer_class, "serialize")
121+
and hasattr(serializer_class, "deserialize")
122+
), """
123+
`serializer_class` should be a class which implements `serialize` and `deserialize` method
124+
or a subclass of `channels_redis.serializers.BaseMessageSerializer`
125+
"""
126+
127+
self._registry[format] = serializer_class
128+
129+
def get_serializer(self, format, *args, **kwargs):
130+
try:
131+
serializer_class = self._registry[format]
132+
except KeyError:
133+
raise SerializerDoesNotExist(format)
134+
135+
return serializer_class(*args, **kwargs)
136+
137+
138+
registry = SerializersRegistry()
139+
registry.register_serializer("json", JSONSerializer)
140+
registry.register_serializer("msgpack", MsgPackSerializer)

0 commit comments

Comments
 (0)