From cda68c7e6141cdac1bb66165accd04ad30181d90 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Wed, 14 Feb 2018 09:35:02 +0100 Subject: [PATCH] Add testcase for channel.new_consumer() contextmanager+iterator --- tests/test_properties.py | 151 +++++++++++++++++++++++++++++++++++++-- trio_amqp/channel.py | 30 ++++---- 2 files changed, 158 insertions(+), 23 deletions(-) diff --git a/tests/test_properties.py b/tests/test_properties.py index 1be9973..63e6114 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -10,8 +10,13 @@ logger = logging.getLogger(__name__) +server_queue_name = 'server_queue' +client_queue_name = 'client_reply_queue' +exchange_name = 'exchange_name' +server_routing_key = 'reply_test' -class TestReply(testcase.RabbitTestCase): +class TestReplyOld(testcase.RabbitTestCase): + """RPC test using classic callbacks""" async def _server( self, amqp, @@ -23,7 +28,6 @@ async def _server( """Consume messages and reply to them by publishing messages back to the client using routing key set to the reply_to property """ - server_queue_name = 'server_queue' async with amqp.new_channel() as channel: await channel.queue_declare( server_queue_name, exclusive=False, no_wait=False @@ -50,9 +54,9 @@ async def server_callback(channel, body, envelope, properties): await channel.basic_consume( server_callback, queue_name=server_queue_name ) + logger.debug('Server consuming messages') task_status.started() await server_future.wait() - logger.debug('Server consuming messages') async def _client( self, @@ -68,7 +72,6 @@ async def _client( message to the server with the reply_to property set to that routing key """ - client_queue_name = 'client_reply_queue' async with amqp.new_channel() as client_channel: await client_channel.queue_declare( client_queue_name, exclusive=True, no_wait=False @@ -101,9 +104,144 @@ async def client_callback(channel, body, envelope, properties): @pytest.mark.trio async def test_reply_to(self, amqp): - exchange_name = 'exchange_name' - server_routing_key = 'reply_test' + server_future = trio.Event() + async with trio.open_nursery() as n: + await n.start( + self._server, amqp, server_future, exchange_name, + server_routing_key + ) + + correlation_id = 'secret correlation id' + client_routing_key = 'secret_client_key' + + client_future = trio.Event() + await n.start( + self._client, amqp, client_future, exchange_name, + server_routing_key, correlation_id, client_routing_key + ) + + logger.debug('Waiting for server to receive message') + await server_future.wait() + server_body, server_envelope, server_properties = \ + server_future.test_result + assert server_body == b'client message' + assert server_properties.correlation_id == correlation_id + assert server_properties.reply_to == client_routing_key + assert server_envelope.routing_key == server_routing_key + + logger.debug('Waiting for client to receive message') + await client_future.wait() + client_body, client_envelope, client_properties = \ + client_future.test_result + assert client_body == b'reply message' + assert client_properties.correlation_id == correlation_id + assert client_envelope.routing_key == client_routing_key + n.cancel_scope.cancel() + +class TestReplyNew(testcase.RabbitTestCase): + """RPC test using iteration""" + async def _server( + self, + amqp, + server_future, + exchange_name, + routing_key, + task_status=trio.TASK_STATUS_IGNORED + ): + """Consume messages and reply to them by publishing messages back + to the client using routing key set to the reply_to property + """ + async with amqp.new_channel() as channel: + await channel.queue_declare( + server_queue_name, exclusive=False, no_wait=False + ) + await channel.exchange_declare(exchange_name, type_name='direct') + await channel.queue_bind( + server_queue_name, exchange_name, routing_key=routing_key + ) + + async with trio.open_nursery() as n: + await n.start(self._server_consumer, channel, server_future) + task_status.started() + await server_future.wait() + self._server_scope.cancel() + + async def _server_consumer(self, channel, server_future, task_status=trio.TASK_STATUS_IGNORED): + with trio.open_cancel_scope() as scope: + self._server_scope = scope + async with channel.new_consumer(queue_name=server_queue_name) \ + as data: + logger.debug('Server consuming messages') + task_status.started() + async for body, envelope, properties in data: + + logger.debug('Server received message') + publish_properties = { + 'correlation_id': properties.correlation_id + } + logger.debug('Replying to %r', properties.reply_to) + await channel.publish( + b'reply message', exchange_name, properties.reply_to, + publish_properties + ) + server_future.test_result = (body, envelope, properties) + server_future.set() + logger.debug('Server replied') + + async def _client( + self, + amqp, + client_future, + exchange_name, + server_routing_key, + correlation_id, + client_routing_key, + task_status=trio.TASK_STATUS_IGNORED + ): + """Declare a queue, bind client_routing_key to it, and publish a + message to the server with the reply_to property set to that + routing key + """ + async with amqp.new_channel() as client_channel: + await client_channel.queue_declare( + client_queue_name, exclusive=True, no_wait=False + ) + await client_channel.queue_bind( + client_queue_name, + exchange_name, + routing_key=client_routing_key + ) + + async with trio.open_nursery() as n: + await n.start(self._client_consumer, client_channel, client_future) + task_status.started() + + await client_channel.publish( + b'client message', exchange_name, server_routing_key, { + 'correlation_id': correlation_id, + 'reply_to': client_routing_key + } + ) + logger.debug('Client published message') + await client_future.wait() + self._client_scope.cancel() + + async def _client_consumer(self, channel, client_future, task_status=trio.TASK_STATUS_IGNORED): + with trio.open_cancel_scope() as scope: + self._client_scope = scope + async with channel.new_consumer(queue_name=client_queue_name) \ + as data: + task_status.started() + logger.debug('Client consuming messages') + + async for body, envelope, properties in data: + logger.debug('Client received message') + client_future.test_result = (body, envelope, properties) + client_future.set() + + @pytest.mark.trio + async def test_reply_to(self, amqp): server_future = trio.Event() async with trio.open_nursery() as n: await n.start( @@ -137,3 +275,4 @@ async def test_reply_to(self, amqp): assert client_properties.correlation_id == correlation_id assert client_envelope.routing_key == client_routing_key n.cancel_scope.cancel() + diff --git a/trio_amqp/channel.py b/trio_amqp/channel.py index 6ef2b4f..6090bdb 100644 --- a/trio_amqp/channel.py +++ b/trio_amqp/channel.py @@ -27,9 +27,8 @@ def __init__(self, channel, consumer_tag, **kwargs): self.channel = channel self.kwargs = kwargs self.consumer_tag = consumer_tag - self._q = trio.Queue(30) # TODO: 2 + possible prefetch - async def __call__(self, msg, env, prop): + async def _data(self, channel, msg, env, prop): await self._q.put((msg, env, prop)) async def __aiter__(self): @@ -40,9 +39,9 @@ async def __anext__(self): async def __aenter__(self): await self.channel.basic_consume( - self, consumer_tag=self.consumer_tag, **self.kwargs + self._data, consumer_tag=self.consumer_tag, **self.kwargs ) - self._q = trio.Queue() + self._q = trio.Queue(30) # TODO: 2 + possible prefetch return self async def __aexit__(self, *tb): @@ -52,6 +51,15 @@ async def __aexit__(self, *tb): # these messages are not acknowledged, thus deleting the queue will # not lose them + def __enter__(self): + raise RuntimeError("You need to use 'async with'.") + + def __exit__(self, *tb): + raise RuntimeError("You need to use 'async with'.") + + def __iter__(self): + raise RuntimeError("You need to use 'async for'.") + class Channel: def __init__(self, protocol, channel_id): @@ -840,7 +848,7 @@ def new_consumer( async def basic_consume( self, - callback=None, + callback, queue_name='', consumer_tag='', no_local=False, @@ -898,18 +906,6 @@ async def basic_consume( if arguments is None: arguments = {} - if callback is None: - return BasicListener( - self, - queue_name=queue_name, - consumer_tag=consumer_tag, - no_local=no_local, - no_ack=no_ack, - exclusive=exclusive, - no_wait=no_wait, - arguments=arguments - ) - frame = amqp_frame.AmqpRequest( amqp_constants.TYPE_METHOD, self.channel_id )