Skip to content

Commit 90ac7f0

Browse files
committed
Add serialization registry
1 parent 13cef45 commit 90ac7f0

File tree

6 files changed

+381
-85
lines changed

6 files changed

+381
-85
lines changed

README.rst

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ to 10, and all ``websocket.send!`` channels to 20:
171171
If you want to enforce a matching order, use an ``OrderedDict`` as the
172172
argument; channels will then be matched in the order the dict provides them.
173173

174+
.. _encryption
174175
``symmetric_encryption_keys``
175176
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
176177

@@ -237,6 +238,51 @@ And then in your channels consumer, you can implement the handler:
237238
async def redis_disconnect(self, *args):
238239
# Handle disconnect
239240
241+
242+
243+
``serializer_format``
244+
~~~~~~~~~~~~~~~~~~~~~~
245+
By default every message sent to redis is encoded using `msgpack <https://msgpack.org/>`_ (_currently ``msgpack`` is a mandatory dependency of this package, it may become optional in a future release_).
246+
It is also possible to switch to `JSON <http://www.json.org/>`_:
247+
248+
.. code-block:: python
249+
250+
CHANNEL_LAYERS = {
251+
"default": {
252+
"BACKEND": "channels_redis.core.RedisChannelLayer",
253+
"CONFIG": {
254+
"hosts": ["redis://:[email protected]:6379/0"],
255+
"serializer_format": "json",
256+
},
257+
},
258+
}
259+
260+
261+
Custom serializer can be defined by:
262+
263+
- extending ``channels_redis.serializers.BaseMessageSerializer``, implementing ``dumps`` and ``loads`` methods
264+
- using any class which accepts generic keyword arguments and provides ``serialize``/``deserialize`` methods
265+
266+
Then it may be registerd (or can be overriden) by using ``channels_redis.serializers.registry``:
267+
268+
.. code-block:: python
269+
270+
from channels_redis.serializers import registry
271+
272+
class MyFormatSerializer:
273+
def serialize(self, message):
274+
...
275+
def deserialize(self, message):
276+
...
277+
278+
registry.register_serializer('myformat', MyFormatSerializer)
279+
280+
**NOTE**: the registry allows to override the serializer class used for a specific format without any particular check nor constraint, thus it is recommended to pay attention with order-of-imports when using third-party serializers which may override a built-in format.
281+
282+
283+
Serializers are also responsible for encryption *symmetric_encryption_keys*. When extending ``channels_redis.serializers.BaseMessageSerializer`` encryption is already configured in the base class, unless you override ``serialize``/``deserialize`` methods: in this case you should call ``self.crypter.encrypt`` in serialization and ``self.crypter.decrypt`` in deserialization process. When using full custom serializer expect an optional sequence of keys to be passed via ``symmetric_encryption_keys``.
284+
285+
240286
Dependencies
241287
------------
242288

channels_redis/core.py

Lines changed: 13 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
import asyncio
2-
import base64
32
import collections
43
import functools
5-
import hashlib
64
import itertools
75
import logging
8-
import random
96
import time
107
import uuid
118

12-
import msgpack
139
from redis import asyncio as aioredis
1410

1511
from channels.exceptions import ChannelFull
1612
from channels.layers import BaseChannelLayer
1713

14+
from .serializers import registry
1815
from .utils import (
1916
_close_redis,
2017
_consistent_hash,
@@ -115,6 +112,8 @@ def __init__(
115112
capacity=100,
116113
channel_capacity=None,
117114
symmetric_encryption_keys=None,
115+
random_prefix_length=12,
116+
serializer_format="msgpack",
118117
):
119118
# Store basic information
120119
self.expiry = expiry
@@ -126,15 +125,21 @@ def __init__(
126125
# Configure the host objects
127126
self.hosts = decode_hosts(hosts)
128127
self.ring_size = len(self.hosts)
128+
# serialization
129+
self._serializer = registry.get_serializer(
130+
serializer_format,
131+
# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
132+
random_prefix_length=random_prefix_length,
133+
expiry=self.expiry,
134+
symmetric_encryption_keys=symmetric_encryption_keys,
135+
)
129136
# Cached redis connection pools and the event loop they are from
130137
self._layers = {}
131138
# Normal channels choose a host index by cycling through the available hosts
132139
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
133140
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
134141
# Decide on a unique client prefix to use in ! sections
135142
self.client_prefix = uuid.uuid4().hex
136-
# Set up any encryption objects
137-
self._setup_encryption(symmetric_encryption_keys)
138143
# Number of coroutines trying to receive right now
139144
self.receive_count = 0
140145
# The receive lock
@@ -154,24 +159,6 @@ def __init__(
154159
def create_pool(self, index):
155160
return create_pool(self.hosts[index])
156161

157-
def _setup_encryption(self, symmetric_encryption_keys):
158-
# See if we can do encryption if they asked
159-
if symmetric_encryption_keys:
160-
if isinstance(symmetric_encryption_keys, (str, bytes)):
161-
raise ValueError(
162-
"symmetric_encryption_keys must be a list of possible keys"
163-
)
164-
try:
165-
from cryptography.fernet import MultiFernet
166-
except ImportError:
167-
raise ValueError(
168-
"Cannot run with encryption without 'cryptography' installed."
169-
)
170-
sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
171-
self.crypter = MultiFernet(sub_fernets)
172-
else:
173-
self.crypter = None
174-
175162
### Channel layer API ###
176163

177164
extensions = ["groups", "flush"]
@@ -656,41 +643,19 @@ def serialize(self, message):
656643
"""
657644
Serializes message to a byte string.
658645
"""
659-
value = msgpack.packb(message, use_bin_type=True)
660-
if self.crypter:
661-
value = self.crypter.encrypt(value)
662-
663-
# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
664-
random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big")
665-
return random_prefix + value
646+
return self._serializer.serialize(message)
666647

667648
def deserialize(self, message):
668649
"""
669650
Deserializes from a byte string.
670651
"""
671-
# Removes the random prefix
672-
message = message[12:]
673-
674-
if self.crypter:
675-
message = self.crypter.decrypt(message, self.expiry + 10)
676-
return msgpack.unpackb(message, raw=False)
652+
return self._serializer.deserialize(message)
677653

678654
### Internal functions ###
679655

680656
def consistent_hash(self, value):
681657
return _consistent_hash(value, self.ring_size)
682658

683-
def make_fernet(self, key):
684-
"""
685-
Given a single encryption key, returns a Fernet instance using it.
686-
"""
687-
from cryptography.fernet import Fernet
688-
689-
if isinstance(key, str):
690-
key = key.encode("utf8")
691-
formatted_key = base64.urlsafe_b64encode(hashlib.sha256(key).digest())
692-
return Fernet(formatted_key)
693-
694659
def __str__(self):
695660
return f"{self.__class__.__name__}(hosts={self.hosts})"
696661

channels_redis/pubsub.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import logging
44
import uuid
55

6-
import msgpack
76
from redis import asyncio as aioredis
87

8+
from .serializers import registry
99
from .utils import (
1010
_close_redis,
1111
_consistent_hash,
@@ -25,10 +25,33 @@ async def _async_proxy(obj, name, *args, **kwargs):
2525

2626

2727
class RedisPubSubChannelLayer:
28-
def __init__(self, *args, **kwargs) -> None:
28+
def __init__(
29+
self,
30+
*args,
31+
symmetric_encryption_keys=None,
32+
serializer_format="msgpack",
33+
**kwargs,
34+
) -> None:
2935
self._args = args
3036
self._kwargs = kwargs
3137
self._layers = {}
38+
# serialization
39+
self._serializer = registry.get_serializer(
40+
serializer_format,
41+
symmetric_encryption_keys=symmetric_encryption_keys,
42+
)
43+
44+
def serialize(self, message):
45+
"""
46+
Serializes message to a byte string.
47+
"""
48+
return self._serializer.serialize(message)
49+
50+
def deserialize(self, message):
51+
"""
52+
Deserializes from a byte string.
53+
"""
54+
return self._serializer.deserialize(message)
3255

3356
def __getattr__(self, name):
3457
if name in (
@@ -44,18 +67,6 @@ def __getattr__(self, name):
4467
else:
4568
return getattr(self._get_layer(), name)
4669

47-
def serialize(self, message):
48-
"""
49-
Serializes message to a byte string.
50-
"""
51-
return msgpack.packb(message)
52-
53-
def deserialize(self, message):
54-
"""
55-
Deserializes from a byte string.
56-
"""
57-
return msgpack.unpackb(message)
58-
5970
def _get_layer(self):
6071
loop = asyncio.get_running_loop()
6172

0 commit comments

Comments
 (0)