Skip to content

Commit

Permalink
Add group_send_multiple method to core layer
Browse files Browse the repository at this point in the history
  • Loading branch information
olzhasar committed Oct 8, 2024
1 parent 4cb9b90 commit 5fa875f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 17 deletions.
50 changes: 33 additions & 17 deletions channels_redis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,13 @@ async def group_send(self, group, message):
"""
Sends a message to the entire group.
"""
await self.group_send_multiple(group, (message,))

async def group_send_multiple(self, group, messages):
"""
Sends a message to the entire group.
"""

assert self.valid_group_name(group), "Group name not valid"
# Retrieve list of all channel names
key = self._group_key(group)
Expand All @@ -553,7 +560,7 @@ async def group_send(self, group, message):
connection_to_channel_keys,
channel_keys_to_message,
channel_keys_to_capacity,
) = self._map_channel_keys_to_connection(channel_names, message)
) = self._map_channel_keys_to_connection(channel_names, messages)

for connection_index, channel_redis_keys in connection_to_channel_keys.items():
# Discard old messages based on expiry
Expand All @@ -565,17 +572,23 @@ async def group_send(self, group, message):
await pipe.execute()

# Create a LUA script specific for this connection.
# Make sure to use the message specific to this channel, it is
# stored in channel_to_message dict and contains the
# Make sure to use the message list specific to this channel, it is
# stored in channel_to_message dict and each message contains the
# __asgi_channel__ key.

group_send_lua = """
local over_capacity = 0
local num_messages = tonumber(ARGV[#ARGV - 2])
local current_time = ARGV[#ARGV - 1]
local expiry = ARGV[#ARGV]
for i=1,#KEYS do
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS]) then
redis.call('ZADD', KEYS[i], current_time, ARGV[i])
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
local messages = {}
for j=num_messages * (i - 1) + 1, num_messages * i do
table.insert(messages, current_time)
table.insert(messages, ARGV[j])
end
redis.call('ZADD', KEYS[i], unpack(messages))
redis.call('EXPIRE', KEYS[i], expiry)
else
over_capacity = over_capacity + 1
Expand All @@ -585,18 +598,18 @@ async def group_send(self, group, message):
"""

# We need to filter the messages to keep those related to the connection
args = [
channel_keys_to_message[channel_key]
for channel_key in channel_redis_keys
]
args = []

for channel_key in channel_redis_keys:
args += channel_keys_to_message[channel_key]

# We need to send the capacity for each channel
args += [
channel_keys_to_capacity[channel_key]
for channel_key in channel_redis_keys
]

args += [time.time(), self.expiry]
args += [len(messages), time.time(), self.expiry]

# channel_keys does not contain a single redis key more than once
connection = self.connection(connection_index)
Expand All @@ -611,7 +624,7 @@ async def group_send(self, group, message):
group,
)

def _map_channel_keys_to_connection(self, channel_names, message):
def _map_channel_keys_to_connection(self, channel_names, messages):
"""
For a list of channel names, GET
Expand All @@ -626,7 +639,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
# Connection dict keyed by index to list of redis keys mapped on that index
connection_to_channel_keys = collections.defaultdict(list)
# Message dict maps redis key to the message that needs to be send on that key
channel_key_to_message = dict()
channel_key_to_message = collections.defaultdict(list)
# Channel key mapped to its capacity
channel_key_to_capacity = dict()

Expand All @@ -640,20 +653,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
# Have we come across the same redis key?
if channel_key not in channel_key_to_message:
# If not, fill the corresponding dicts
message = dict(message.items())
message["__asgi_channel__"] = [channel]
channel_key_to_message[channel_key] = message
for message in messages:
message = dict(message.items())
message["__asgi_channel__"] = [channel]
channel_key_to_message[channel_key].append(message)
channel_key_to_capacity[channel_key] = self.get_capacity(channel)
idx = self.consistent_hash(channel_non_local_name)
connection_to_channel_keys[idx].append(channel_key)
else:
# Yes, Append the channel in message dict
channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)
for message in channel_key_to_message[channel_key]:
message["__asgi_channel__"].append(channel)

# Now that we know what message needs to be send on a redis key we serialize it
for key, value in channel_key_to_message.items():
# Serialize the message stored for each redis key
channel_key_to_message[key] = self.serialize(value)
for idx, message in enumerate(value):
channel_key_to_message[key][idx] = self.serialize(message)

return (
connection_to_channel_keys,
Expand Down
35 changes: 35 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import collections
import random

import async_timeout
Expand Down Expand Up @@ -244,6 +245,40 @@ async def test_groups_basic(channel_layer):
await channel_layer.flush()


@pytest.mark.asyncio
async def test_groups_multiple(channel_layer):
"""
Tests basic group operation.
"""
channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
await channel_layer.group_add("test-group", channel_name1)
await channel_layer.group_add("test-group", channel_name2)
await channel_layer.group_add("test-group", channel_name3)

messages = [
{"type": "message.1"},
{"type": "message.2"},
{"type": "message.3"},
]

expected = {msg["type"] for msg in messages}

await channel_layer.group_send_multiple("test-group", messages)

received = collections.defaultdict(set)

for channel_name in (channel_name1, channel_name2, channel_name3):
async with async_timeout.timeout(1):
for _ in range(len(messages)):
received[channel_name].add(
(await channel_layer.receive(channel_name))["type"]
)

assert received[channel_name] == expected


@pytest.mark.asyncio
async def test_groups_channel_full(channel_layer):
"""
Expand Down

0 comments on commit 5fa875f

Please sign in to comment.