diff --git a/.travis.yml b/.travis.yml index 07a7c9a..bd45722 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,10 +1,10 @@ language: python -dist: bionic +dist: focal python: -- 3.5 - 3.6 - 3.7 - 3.8 +- 3.9 services: - rabbitmq install: diff --git a/Dockerfile b/Dockerfile index 7ec4545..1ea92f1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.5 +FROM python:3.9 WORKDIR /usr/src/app diff --git a/aioamqp/__init__.py b/aioamqp/__init__.py new file mode 100644 index 0000000..030e7db --- /dev/null +++ b/aioamqp/__init__.py @@ -0,0 +1,103 @@ +import asyncio +import socket +from urllib.parse import urlparse + +from .exceptions import * # pylint: disable=wildcard-import +from .protocol import AmqpProtocol + +from .version import __version__ +from .version import __packagename__ + + +async def connect(host='localhost', port=None, login='guest', password='guest', + virtualhost='/', ssl=None, login_method='PLAIN', insist=False, + protocol_factory=AmqpProtocol, *, loop=None, **kwargs): + """Convenient method to connect to an AMQP broker + + @host: the host to connect to + @port: broker port + @login: login + @password: password + @virtualhost: AMQP virtualhost to use for this connection + @ssl: SSL context used for secure connections, omit for no SSL + - see https://docs.python.org/3/library/ssl.html + @login_method: AMQP auth method + @insist: Insist on connecting to a server + @protocol_factory: + Factory to use, if you need to subclass AmqpProtocol + @loop: Set the event loop to use + + @kwargs: Arguments to be given to the protocol_factory instance + + Returns: a tuple (transport, protocol) of an AmqpProtocol instance + """ + if loop is None: + loop = asyncio.get_event_loop() + factory = lambda: protocol_factory(loop=loop, **kwargs) + + create_connection_kwargs = {} + + if ssl is not None: + create_connection_kwargs['ssl'] = ssl + + if port is None: + if ssl: + port = 5671 + else: + port = 5672 + + transport, protocol = await loop.create_connection( + factory, host, port, **create_connection_kwargs + ) + + # these 2 flags *may* show up in sock.type. They are only available on linux + # see https://bugs.python.org/issue21327 + nonblock = getattr(socket, 'SOCK_NONBLOCK', 0) + cloexec = getattr(socket, 'SOCK_CLOEXEC', 0) + sock = transport.get_extra_info('socket') + if sock is not None and (sock.type & ~nonblock & ~cloexec) == socket.SOCK_STREAM: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + try: + await protocol.start_connection(host, port, login, password, virtualhost, ssl=ssl, login_method=login_method, + insist=insist) + except Exception: + await protocol.wait_closed() + raise + + return transport, protocol + + +async def from_url( + url, login_method='PLAIN', insist=False, protocol_factory=AmqpProtocol, **kwargs): + """ Connect to the AMQP using a single url parameter and return the client. + + For instance: + + amqp://user:password@hostname:port/vhost + + @insist: Insist on connecting to a server + @protocol_factory: + Factory to use, if you need to subclass AmqpProtocol + @loop: optionally set the event loop to use. 
+ + @kwargs: Arguments to be given to the protocol_factory instance + + Returns: a tuple (transport, protocol) of an AmqpProtocol instance + """ + url = urlparse(url) + + if url.scheme not in ('amqp', 'amqps'): + raise ValueError('Invalid protocol %s, valid protocols are amqp or amqps' % url.scheme) + + transport, protocol = await connect( + host=url.hostname or 'localhost', + port=url.port, + login=url.username or 'guest', + password=url.password or 'guest', + virtualhost=(url.path[1:] if len(url.path) > 1 else '/'), + login_method=login_method, + insist=insist, + protocol_factory=protocol_factory, + **kwargs) + return transport, protocol diff --git a/aioamqp/channel.py b/aioamqp/channel.py new file mode 100644 index 0000000..b58bd5a --- /dev/null +++ b/aioamqp/channel.py @@ -0,0 +1,736 @@ +""" + Amqp channel specification +""" + +import asyncio +import logging +import uuid +import io +from itertools import count +import warnings + +import pamqp.specification + +from . import frame as amqp_frame +from . import exceptions +from . import properties as amqp_properties +from .envelope import Envelope, ReturnEnvelope + + +logger = logging.getLogger(__name__) + + +class Channel: + + def __init__(self, protocol, channel_id, return_callback=None): + self._loop = protocol._loop + self.protocol = protocol + self.channel_id = channel_id + self.consumer_queues = {} + self.consumer_callbacks = {} + self.cancellation_callbacks = [] + self.return_callback = return_callback + self.response_future = None + self.close_event = asyncio.Event() + self.cancelled_consumers = set() + self.last_consumer_tag = None + self.publisher_confirms = False + self.delivery_tag_iter = None # used for mapping delivered messages to publisher confirms + + self._exchange_declare_lock = asyncio.Lock() + self._queue_bind_lock = asyncio.Lock() + self._futures = {} + self._ctag_events = {} + + def _set_waiter(self, rpc_name): + if rpc_name in self._futures: + raise exceptions.SynchronizationError("Waiter already exists") + + fut = asyncio.Future(loop=self._loop) + self._futures[rpc_name] = fut + return fut + + def _get_waiter(self, rpc_name): + fut = self._futures.pop(rpc_name, None) + if not fut: + raise exceptions.SynchronizationError("Call %s didn't set a waiter" % rpc_name) + return fut + + @property + def is_open(self): + return not self.close_event.is_set() + + def connection_closed(self, server_code=None, server_reason=None, exception=None): + for future in self._futures.values(): + if future.done(): + continue + if exception is None: + kwargs = {} + if server_code is not None: + kwargs['code'] = server_code + if server_reason is not None: + kwargs['message'] = server_reason + exception = exceptions.ChannelClosed(**kwargs) + future.set_exception(exception) + + self.protocol.release_channel_id(self.channel_id) + self.close_event.set() + + async def dispatch_frame(self, frame): + methods = { + pamqp.specification.Channel.OpenOk.name: self.open_ok, + pamqp.specification.Channel.FlowOk.name: self.flow_ok, + pamqp.specification.Channel.CloseOk.name: self.close_ok, + pamqp.specification.Channel.Close.name: self.server_channel_close, + + pamqp.specification.Exchange.DeclareOk.name: self.exchange_declare_ok, + pamqp.specification.Exchange.BindOk.name: self.exchange_bind_ok, + pamqp.specification.Exchange.UnbindOk.name: self.exchange_unbind_ok, + pamqp.specification.Exchange.DeleteOk.name: self.exchange_delete_ok, + + pamqp.specification.Queue.DeclareOk.name: self.queue_declare_ok, + pamqp.specification.Queue.DeleteOk.name: 
self.queue_delete_ok, + pamqp.specification.Queue.BindOk.name: self.queue_bind_ok, + pamqp.specification.Queue.UnbindOk.name: self.queue_unbind_ok, + pamqp.specification.Queue.PurgeOk.name: self.queue_purge_ok, + + pamqp.specification.Basic.QosOk.name: self.basic_qos_ok, + pamqp.specification.Basic.ConsumeOk.name: self.basic_consume_ok, + pamqp.specification.Basic.CancelOk.name: self.basic_cancel_ok, + pamqp.specification.Basic.GetOk.name: self.basic_get_ok, + pamqp.specification.Basic.GetEmpty.name: self.basic_get_empty, + pamqp.specification.Basic.Deliver.name: self.basic_deliver, + pamqp.specification.Basic.Cancel.name: self.server_basic_cancel, + pamqp.specification.Basic.Ack.name: self.basic_server_ack, + pamqp.specification.Basic.Nack.name: self.basic_server_nack, + pamqp.specification.Basic.RecoverOk.name: self.basic_recover_ok, + pamqp.specification.Basic.Return.name: self.basic_return, + + pamqp.specification.Confirm.SelectOk.name: self.confirm_select_ok, + } + + if frame.name not in methods: + raise NotImplementedError("Frame %s is not implemented" % frame.name) + + await methods[frame.name](frame) + + async def _write_frame(self, channel_id, request, check_open=True, drain=True): + await self.protocol.ensure_open() + if not self.is_open and check_open: + raise exceptions.ChannelClosed() + amqp_frame.write(self.protocol._stream_writer, channel_id, request) + if drain: + await self.protocol._drain() + + async def _write_frame_awaiting_response(self, waiter_id, channel_id, request, + no_wait, check_open=True, drain=True): + '''Write a frame and set a waiter for the response (unless no_wait is set)''' + if no_wait: + await self._write_frame(channel_id, request, check_open=check_open, drain=drain) + return None + + f = self._set_waiter(waiter_id) + try: + await self._write_frame(channel_id, request, check_open=check_open, drain=drain) + except Exception: + self._get_waiter(waiter_id) + f.cancel() + raise + return (await f) + +# +## Channel class implementation +# + + async def open(self): + """Open the channel on the server.""" + request = pamqp.specification.Channel.Open() + return (await self._write_frame_awaiting_response( + 'open', self.channel_id, request, no_wait=False, check_open=False)) + + async def open_ok(self, frame): + self.close_event.clear() + fut = self._get_waiter('open') + fut.set_result(True) + logger.debug("Channel is open") + + async def close(self, reply_code=0, reply_text="Normal Shutdown"): + """Close the channel.""" + if not self.is_open: + raise exceptions.ChannelClosed("channel already closed or closing") + self.close_event.set() + request = pamqp.specification.Channel.Close(reply_code, reply_text, class_id=0, method_id=0) + return (await self._write_frame_awaiting_response( + 'close', self.channel_id, request, no_wait=False, check_open=False)) + + async def close_ok(self, frame): + self._get_waiter('close').set_result(True) + logger.info("Channel closed") + self.protocol.release_channel_id(self.channel_id) + + async def _send_channel_close_ok(self): + request = pamqp.specification.Channel.CloseOk() + await self._write_frame(self.channel_id, request) + + async def server_channel_close(self, frame): + await self._send_channel_close_ok() + results = { + 'reply_code': frame.reply_code, + 'reply_text': frame.reply_text, + 'class_id': frame.class_id, + 'method_id': frame.method_id, + } + self.connection_closed(results['reply_code'], results['reply_text']) + + async def flow(self, active): + request = pamqp.specification.Channel.Flow(active) + return (await 
self._write_frame_awaiting_response( + 'flow', self.channel_id, request, no_wait=False, + check_open=False)) + + async def flow_ok(self, frame): + self.close_event.clear() + fut = self._get_waiter('flow') + fut.set_result({'active': frame.active}) + + logger.debug("Flow ok") + +# +## Exchange class implementation +# + + async def exchange_declare(self, exchange_name, type_name, passive=False, durable=False, + auto_delete=False, no_wait=False, arguments=None): + request = pamqp.specification.Exchange.Declare( + exchange=exchange_name, + exchange_type=type_name, + passive=passive, + durable=durable, + auto_delete=auto_delete, + nowait=no_wait, + arguments=arguments + ) + + async with self._exchange_declare_lock: + return (await self._write_frame_awaiting_response( + 'exchange_declare', self.channel_id, request, no_wait)) + + async def exchange_declare_ok(self, frame): + future = self._get_waiter('exchange_declare') + future.set_result(True) + logger.debug("Exchange declared") + return future + + async def exchange_delete(self, exchange_name, if_unused=False, no_wait=False): + request = pamqp.specification.Exchange.Delete(exchange=exchange_name, if_unused=if_unused, nowait=no_wait) + return await self._write_frame_awaiting_response( + 'exchange_delete', self.channel_id, request, no_wait) + + async def exchange_delete_ok(self, frame): + future = self._get_waiter('exchange_delete') + future.set_result(True) + logger.debug("Exchange deleted") + + async def exchange_bind(self, exchange_destination, exchange_source, routing_key, + no_wait=False, arguments=None): + if arguments is None: + arguments = {} + request = pamqp.specification.Exchange.Bind( + destination=exchange_destination, + source=exchange_source, + routing_key=routing_key, + nowait=no_wait, + arguments=arguments + ) + return (await self._write_frame_awaiting_response( + 'exchange_bind', self.channel_id, request, no_wait)) + + async def exchange_bind_ok(self, frame): + future = self._get_waiter('exchange_bind') + future.set_result(True) + logger.debug("Exchange bound") + + async def exchange_unbind(self, exchange_destination, exchange_source, routing_key, + no_wait=False, arguments=None): + if arguments is None: + arguments = {} + + request = pamqp.specification.Exchange.Unbind( + destination=exchange_destination, + source=exchange_source, + routing_key=routing_key, + nowait=no_wait, + arguments=arguments, + ) + return (await self._write_frame_awaiting_response( + 'exchange_unbind', self.channel_id, request, no_wait)) + + async def exchange_unbind_ok(self, frame): + future = self._get_waiter('exchange_unbind') + future.set_result(True) + logger.debug("Exchange bound") + +# +## Queue class implementation +# + + async def queue_declare(self, queue_name=None, passive=False, durable=False, + exclusive=False, auto_delete=False, no_wait=False, arguments=None): + """Create or check a queue on the broker + Args: + queue_name: str, the queue to receive message from. + The server generate a queue_name if not specified. + passive: bool, if set, the server will reply with + Declare-Ok if the queue already exists with the same name, and + raise an error if not. Checks for the same parameter as well. + durable: bool: If set when creating a new queue, the queue + will be marked as durable. Durable queues remain active when a + server restarts. 
+ exclusive: bool, request exclusive consumer access, + meaning only this consumer can access the queue + no_wait: bool, if set, the server will not respond to the method + arguments: dict, AMQP arguments to be passed when creating + the queue. + """ + if arguments is None: + arguments = {} + + if not queue_name: + queue_name = 'aioamqp.gen-' + str(uuid.uuid4()) + request = pamqp.specification.Queue.Declare( + queue=queue_name, + passive=passive, + durable=durable, + exclusive=exclusive, + auto_delete=auto_delete, + nowait=no_wait, + arguments=arguments + ) + return (await self._write_frame_awaiting_response( + 'queue_declare' + queue_name, self.channel_id, request, no_wait)) + + async def queue_declare_ok(self, frame): + results = { + 'queue': frame.queue, + 'message_count': frame.message_count, + 'consumer_count': frame.consumer_count, + } + future = self._get_waiter('queue_declare' + results['queue']) + future.set_result(results) + logger.debug("Queue declared") + + async def queue_delete(self, queue_name, if_unused=False, if_empty=False, no_wait=False): + """Delete a queue in RabbitMQ + Args: + queue_name: str, the queue to receive message from + if_unused: bool, the queue is deleted if it has no consumers. Raise if not. + if_empty: bool, the queue is deleted if it has no messages. Raise if not. + no_wait: bool, if set, the server will not respond to the method + """ + request = pamqp.specification.Queue.Delete( + queue=queue_name, + if_unused=if_unused, + if_empty=if_empty, + nowait=no_wait + ) + return (await self._write_frame_awaiting_response( + 'queue_delete', self.channel_id, request, no_wait)) + + async def queue_delete_ok(self, frame): + future = self._get_waiter('queue_delete') + future.set_result(True) + logger.debug("Queue deleted") + + async def queue_bind(self, queue_name, exchange_name, routing_key, no_wait=False, arguments=None): + """Bind a queue and a channel.""" + if arguments is None: + arguments = {} + + request = pamqp.specification.Queue.Bind( + queue=queue_name, + exchange=exchange_name, + routing_key=routing_key, + nowait=no_wait, + arguments=arguments + ) + # short reserved-1 + async with self._queue_bind_lock: + return (await self._write_frame_awaiting_response( + 'queue_bind', self.channel_id, request, no_wait)) + + async def queue_bind_ok(self, frame): + future = self._get_waiter('queue_bind') + future.set_result(True) + logger.debug("Queue bound") + + async def queue_unbind(self, queue_name, exchange_name, routing_key, arguments=None): + if arguments is None: + arguments = {} + + request = pamqp.specification.Queue.Unbind( + queue=queue_name, + exchange=exchange_name, + routing_key=routing_key, + arguments=arguments + ) + + return (await self._write_frame_awaiting_response( + 'queue_unbind', self.channel_id, request, no_wait=False)) + + async def queue_unbind_ok(self, frame): + future = self._get_waiter('queue_unbind') + future.set_result(True) + logger.debug("Queue unbound") + + async def queue_purge(self, queue_name, no_wait=False): + request = pamqp.specification.Queue.Purge( + queue=queue_name, nowait=no_wait + ) + return (await self._write_frame_awaiting_response( + 'queue_purge', self.channel_id, request, no_wait=no_wait)) + + async def queue_purge_ok(self, frame): + future = self._get_waiter('queue_purge') + future.set_result({'message_count': frame.message_count}) + +# +## Basic class implementation +# + + async def basic_publish(self, payload, exchange_name, routing_key, + properties=None, mandatory=False, immediate=False): + if isinstance(payload, 
str): + warnings.warn("Str payload support will be removed in next release", DeprecationWarning) + payload = payload.encode() + + if properties is None: + properties = {} + + method_request = pamqp.specification.Basic.Publish( + exchange=exchange_name, + routing_key=routing_key, + mandatory=mandatory, + immediate=immediate + ) + + await self._write_frame(self.channel_id, method_request, drain=False) + + header_request = pamqp.header.ContentHeader( + body_size=len(payload), + properties=pamqp.specification.Basic.Properties(**properties) + ) + await self._write_frame(self.channel_id, header_request, drain=False) + + # split the payload + + frame_max = self.protocol.server_frame_max or len(payload) + for chunk in (payload[0+i:frame_max+i] for i in range(0, len(payload), frame_max)): + content_request = pamqp.body.ContentBody(chunk) + await self._write_frame(self.channel_id, content_request, drain=False) + + await self.protocol._drain() + + async def basic_qos(self, prefetch_size=0, prefetch_count=0, connection_global=False): + """Specifies quality of service. + + Args: + prefetch_size: int, request that messages be sent in advance so + that when the client finishes processing a message, the + following message is already held locally + prefetch_count: int: Specifies a prefetch window in terms of + whole messages. This field may be used in combination with the + prefetch-size field; a message will only be sent in advance if + both prefetch windows (and those at the channel and connection + level) allow it + connection_global: bool: global=false means that the QoS + settings should apply per-consumer channel; and global=true to mean + that the QoS settings should apply per-channel. + """ + request = pamqp.specification.Basic.Qos( + prefetch_size, prefetch_count, connection_global + ) + return (await self._write_frame_awaiting_response( + 'basic_qos', self.channel_id, request, no_wait=False) + ) + + async def basic_qos_ok(self, frame): + future = self._get_waiter('basic_qos') + future.set_result(True) + logger.debug("Qos ok") + + + async def basic_server_nack(self, frame, delivery_tag=None): + if delivery_tag is None: + delivery_tag = frame.delivery_tag + fut = self._get_waiter('basic_server_ack_{}'.format(delivery_tag)) + logger.debug('Received nack for delivery tag %r', delivery_tag) + fut.set_exception(exceptions.PublishFailed(delivery_tag)) + + async def basic_consume(self, callback, queue_name='', consumer_tag='', no_local=False, no_ack=False, + exclusive=False, no_wait=False, arguments=None): + """Starts the consumption of message into a queue. + the callback will be called each time we're receiving a message. + + Args: + callback: coroutine, the called callback + queue_name: str, the queue to receive message from + consumer_tag: str, optional consumer tag + no_local: bool, if set the server will not send messages + to the connection that published them. 
+ no_ack: bool, if set the server does not expect + acknowledgements for messages + exclusive: bool, request exclusive consumer access, + meaning only this consumer can access the queue + no_wait: bool, if set, the server will not respond to the method + arguments: dict, AMQP arguments to be passed to the server + """ + # If a consumer tag was not passed, create one + consumer_tag = consumer_tag or 'ctag%i.%s' % (self.channel_id, uuid.uuid4().hex) + + if arguments is None: + arguments = {} + + request = pamqp.specification.Basic.Consume( + queue=queue_name, + consumer_tag=consumer_tag, + no_local=no_local, + no_ack=no_ack, + exclusive=exclusive, + nowait=no_wait, + arguments=arguments + ) + + self.consumer_callbacks[consumer_tag] = callback + self.last_consumer_tag = consumer_tag + + return_value = await self._write_frame_awaiting_response( + 'basic_consume' + consumer_tag, self.channel_id, request, no_wait) + if no_wait: + return_value = {'consumer_tag': consumer_tag} + else: + self._ctag_events[consumer_tag].set() + return return_value + + async def basic_consume_ok(self, frame): + ctag = frame.consumer_tag + results = { + 'consumer_tag': ctag, + } + future = self._get_waiter('basic_consume' + ctag) + future.set_result(results) + self._ctag_events[ctag] = asyncio.Event() + + async def basic_deliver(self, frame): + consumer_tag = frame.consumer_tag + delivery_tag = frame.delivery_tag + is_redeliver = frame.redelivered + exchange_name = frame.exchange + routing_key = frame.routing_key + _channel, content_header_frame = await self.protocol.get_frame() + + buffer = io.BytesIO() + + while(buffer.tell() < content_header_frame.body_size): + _channel, content_body_frame = await self.protocol.get_frame() + buffer.write(content_body_frame.value) + + body = buffer.getvalue() + envelope = Envelope(consumer_tag, delivery_tag, exchange_name, routing_key, is_redeliver) + properties = amqp_properties.from_pamqp(content_header_frame.properties) + + callback = self.consumer_callbacks[consumer_tag] + + event = self._ctag_events.get(consumer_tag) + if event: + await event.wait() + del self._ctag_events[consumer_tag] + + await callback(self, body, envelope, properties) + + async def server_basic_cancel(self, frame): + # https://www.rabbitmq.com/consumer-cancel.html + consumer_tag = frame.consumer_tag + _no_wait = frame.nowait + self.cancelled_consumers.add(consumer_tag) + logger.info("consume cancelled received") + for callback in self.cancellation_callbacks: + try: + await callback(self, consumer_tag) + except Exception as error: # pylint: disable=broad-except + logger.error("cancellation callback %r raised exception %r", + callback, error) + + async def basic_cancel(self, consumer_tag, no_wait=False): + request = pamqp.specification.Basic.Cancel(consumer_tag, no_wait) + return (await self._write_frame_awaiting_response( + 'basic_cancel', self.channel_id, request, no_wait=no_wait) + ) + + async def basic_cancel_ok(self, frame): + results = { + 'consumer_tag': frame.consumer_tag, + } + future = self._get_waiter('basic_cancel') + future.set_result(results) + logger.debug("Cancel ok") + + async def basic_get(self, queue_name='', no_ack=False): + request = pamqp.specification.Basic.Get(queue=queue_name, no_ack=no_ack) + return (await self._write_frame_awaiting_response( + 'basic_get', self.channel_id, request, no_wait=False) + ) + + async def basic_get_ok(self, frame): + data = { + 'delivery_tag': frame.delivery_tag, + 'redelivered': frame.redelivered, + 'exchange_name': frame.exchange, + 'routing_key': 
frame.routing_key, + 'message_count': frame.message_count, + } + + _channel, content_header_frame = await self.protocol.get_frame() + + buffer = io.BytesIO() + while(buffer.tell() < content_header_frame.body_size): + _channel, content_body_frame = await self.protocol.get_frame() + buffer.write(content_body_frame.value) + + data['message'] = buffer.getvalue() + data['properties'] = amqp_properties.from_pamqp(content_header_frame.properties) + future = self._get_waiter('basic_get') + future.set_result(data) + + async def basic_get_empty(self, frame): + future = self._get_waiter('basic_get') + future.set_exception(exceptions.EmptyQueue) + + async def basic_client_ack(self, delivery_tag, multiple=False): + request = pamqp.specification.Basic.Ack(delivery_tag, multiple) + await self._write_frame(self.channel_id, request) + + async def basic_client_nack(self, delivery_tag, multiple=False, requeue=True): + request = pamqp.specification.Basic.Nack(delivery_tag, multiple, requeue) + await self._write_frame(self.channel_id, request) + + async def basic_server_ack(self, frame): + delivery_tag = frame.delivery_tag + fut = self._get_waiter('basic_server_ack_{}'.format(delivery_tag)) + logger.debug('Received ack for delivery tag %s', delivery_tag) + fut.set_result(True) + + async def basic_reject(self, delivery_tag, requeue=False): + request = pamqp.specification.Basic.Reject(delivery_tag, requeue) + await self._write_frame(self.channel_id, request) + + async def basic_recover_async(self, requeue=True): + request = pamqp.specification.Basic.RecoverAsync(requeue) + await self._write_frame(self.channel_id, request) + + async def basic_recover(self, requeue=True): + request = pamqp.specification.Basic.Recover(requeue) + return (await self._write_frame_awaiting_response( + 'basic_recover', self.channel_id, request, no_wait=False) + ) + + async def basic_recover_ok(self, frame): + future = self._get_waiter('basic_recover') + future.set_result(True) + logger.debug("Cancel ok") + + async def basic_return(self, frame): + reply_code = frame.reply_code + reply_text = frame.reply_text + exchange_name = frame.exchange + routing_key = frame.routing_key + _channel, content_header_frame = await self.protocol.get_frame() + + buffer = io.BytesIO() + while(buffer.tell() < content_header_frame.body_size): + _channel, content_body_frame = await self.protocol.get_frame() + buffer.write(content_body_frame.value) + + body = buffer.getvalue() + envelope = ReturnEnvelope(reply_code, reply_text, + exchange_name, routing_key) + properties = amqp_properties.from_pamqp(content_header_frame.properties) + callback = self.return_callback + if callback is None: + # they have set mandatory bit, but havent added a callback + logger.warning('You have received a returned message, but dont have a callback registered for returns.' 
+ ' Please set channel.return_callback') + else: + await callback(self, body, envelope, properties) + + +# +## convenient aliases +# + queue = queue_declare + exchange = exchange_declare + + async def publish(self, payload, exchange_name, routing_key, properties=None, mandatory=False, immediate=False): + if isinstance(payload, str): + warnings.warn("Str payload support will be removed in next release", DeprecationWarning) + payload = payload.encode() + + if properties is None: + properties = {} + + if self.publisher_confirms: + delivery_tag = next(self.delivery_tag_iter) # pylint: disable=stop-iteration-return + fut = self._set_waiter('basic_server_ack_{}'.format(delivery_tag)) + + method_request = pamqp.specification.Basic.Publish( + exchange=exchange_name, + routing_key=routing_key, + mandatory=mandatory, + immediate=immediate + ) + await self._write_frame(self.channel_id, method_request, drain=False) + + properties = pamqp.specification.Basic.Properties(**properties) + header_request = pamqp.header.ContentHeader( + body_size=len(payload), properties=properties + ) + await self._write_frame(self.channel_id, header_request, drain=False) + + # split the payload + + frame_max = self.protocol.server_frame_max or len(payload) + for chunk in (payload[0+i:frame_max+i] for i in range(0, len(payload), frame_max)): + content_request = pamqp.body.ContentBody(chunk) + await self._write_frame(self.channel_id, content_request, drain=False) + + await self.protocol._drain() + + if self.publisher_confirms: + await fut + + async def confirm_select(self, *, no_wait=False): + if self.publisher_confirms: + raise ValueError('publisher confirms already enabled') + request = pamqp.specification.Confirm.Select(nowait=no_wait) + + return (await self._write_frame_awaiting_response( + 'confirm_select', self.channel_id, request, no_wait) + ) + + async def confirm_select_ok(self, frame): + self.publisher_confirms = True + self.delivery_tag_iter = count(1) + fut = self._get_waiter('confirm_select') + fut.set_result(True) + logger.debug("Confirm selected") + + def add_cancellation_callback(self, callback): + """Add a callback that is invoked when a consumer is cancelled. + + :param callback: function to call + + `callback` is called with the channel and consumer tag as positional + parameters. The callback can be either a plain callable or an + asynchronous co-routine. + + """ + self.cancellation_callbacks.append(callback) diff --git a/aioamqp/protocol.py b/aioamqp/protocol.py new file mode 100644 index 0000000..75d1815 --- /dev/null +++ b/aioamqp/protocol.py @@ -0,0 +1,469 @@ +""" + Amqp Protocol +""" + +import asyncio +import logging + +import pamqp.frame +import pamqp.heartbeat +import pamqp.specification + +from . import channel as amqp_channel +from . import constants as amqp_constants +from . import frame as amqp_frame +from . import exceptions +from . import version + + +logger = logging.getLogger(__name__) + + +CONNECTING, OPEN, CLOSING, CLOSED = range(4) + + +class _StreamWriter(asyncio.StreamWriter): + + def write(self, data): + super().write(data) + self._protocol._heartbeat_timer_send_reset() + + def writelines(self, data): + super().writelines(data) + self._protocol._heartbeat_timer_send_reset() + + def write_eof(self): + ret = super().write_eof() + self._protocol._heartbeat_timer_send_reset() + return ret + + +class AmqpProtocol(asyncio.StreamReaderProtocol): + """The AMQP protocol for asyncio. 
+ + See http://docs.python.org/3.4/library/asyncio-protocol.html#protocols for more information + on asyncio's protocol API. + + """ + + CHANNEL_FACTORY = amqp_channel.Channel + + def __init__(self, *args, **kwargs): + """Defines our new protocol instance + + Args: + channel_max: int, specifies highest channel number that the server permits. + Usable channel numbers are in the range 1..channel-max. + Zero indicates no specified limit. + frame_max: int, the largest frame size that the server proposes for the connection, + including frame header and end-byte. The client can negotiate a lower value. + Zero means that the server does not impose any specific limit + but may reject very large frames if it cannot allocate resources for them. + heartbeat: int, the delay, in seconds, of the connection heartbeat that the server wants. + Zero means the server does not want a heartbeat. + loop: Asyncio.Eventloop: specify the eventloop to use. + client_properties: dict, client-props to tune the client identification + """ + self._loop = kwargs.get('loop') or asyncio.get_event_loop() + self._reader = asyncio.StreamReader(loop=self._loop) + super().__init__(self._reader, loop=self._loop) + self._on_error_callback = kwargs.get('on_error') + + self.client_properties = kwargs.get('client_properties', {}) + self.connection_tunning = {} + if 'channel_max' in kwargs: + self.connection_tunning['channel_max'] = kwargs.get('channel_max') + if 'frame_max' in kwargs: + self.connection_tunning['frame_max'] = kwargs.get('frame_max') + if 'heartbeat' in kwargs: + self.connection_tunning['heartbeat'] = kwargs.get('heartbeat') + + self.connecting = asyncio.Future(loop=self._loop) + self.connection_closed = asyncio.Event() + self.stop_now = asyncio.Future(loop=self._loop) + self.state = CONNECTING + self.version_major = None + self.version_minor = None + self.server_properties = None + self.server_mechanisms = None + self.server_locales = None + self.worker = None + self.server_heartbeat = None + self._heartbeat_timer_recv = None + self._heartbeat_timer_send = None + self._heartbeat_trigger_send = asyncio.Event() + self._heartbeat_worker = None + self.channels = {} + self.server_frame_max = None + self.server_channel_max = None + self.channels_ids_ceil = 0 + self.channels_ids_free = set() + self._drain_lock = asyncio.Lock() + + def connection_made(self, transport): + super().connection_made(transport) + self._stream_writer = _StreamWriter(transport, self, self._stream_reader, self._loop) + + def eof_received(self): + super().eof_received() + # Python 3.5+ started returning True here to keep the transport open. + # We really couldn't care less so keep the behavior from 3.4 to make + # sure connection_lost() is called. + return False + + def connection_lost(self, exc): + if exc is not None: + logger.warning("Connection lost exc=%r", exc) + self.connection_closed.set() + self.state = CLOSED + self._close_channels(exception=exc) + self._heartbeat_stop() + super().connection_lost(exc) + + def data_received(self, data): + self._heartbeat_timer_recv_reset() + super().data_received(data) + + async def ensure_open(self): + # Raise a suitable exception if the connection isn't open. + # Handle cases from the most common to the least common. + + if self.state == OPEN: + return + + if self.state == CLOSED: + raise exceptions.AmqpClosedConnection() + + # If the closing handshake is in progress, let it complete. 
+ if self.state == CLOSING: + await self.wait_closed() + raise exceptions.AmqpClosedConnection() + + # Control may only reach this point in buggy third-party subclasses. + assert self.state == CONNECTING + raise exceptions.AioamqpException("connection isn't established yet.") + + async def _drain(self): + async with self._drain_lock: + # drain() cannot be called concurrently by multiple coroutines: + # http://bugs.python.org/issue29930. Remove this lock when no + # version of Python where this bugs exists is supported anymore. + await self._stream_writer.drain() + + async def _write_frame(self, channel_id, request, drain=True): + amqp_frame.write(self._stream_writer, channel_id, request) + if drain: + await self._drain() + + async def close(self, no_wait=False, timeout=None): + """Close connection (and all channels)""" + await self.ensure_open() + self.state = CLOSING + request = pamqp.specification.Connection.Close( + reply_code=0, + reply_text='', + class_id=0, + method_id=0 + ) + + await self._write_frame(0, request) + if not no_wait: + await self.wait_closed(timeout=timeout) + + async def wait_closed(self, timeout=None): + await asyncio.wait_for(self.connection_closed.wait(), timeout=timeout) + if self._heartbeat_worker is not None: + try: + await asyncio.wait_for(self._heartbeat_worker, timeout=timeout) + except asyncio.CancelledError: + pass + + async def close_ok(self, frame): + self._stream_writer.close() + + async def start_connection(self, host, port, login, password, virtualhost, ssl=False, + login_method='PLAIN', insist=False): + """Initiate a connection at the protocol level + We send `PROTOCOL_HEADER' + """ + + if login_method != 'PLAIN': + logger.warning('login_method %s is not supported, falling back to PLAIN', login_method) + + self._stream_writer.write(amqp_constants.PROTOCOL_HEADER) + + # Wait 'start' method from the server + await self.dispatch_frame() + + client_properties = { + 'capabilities': { + 'consumer_cancel_notify': True, + 'connection.blocked': False, + }, + 'copyright': 'BSD', + 'product': version.__package__, + 'product_version': version.__version__, + } + client_properties.update(self.client_properties) + auth = { + 'LOGIN': login, + 'PASSWORD': password, + } + + # waiting reply start with credentions and co + await self.start_ok(client_properties, 'PLAIN', auth, self.server_locales) + + # wait for a "tune" reponse + await self.dispatch_frame() + + tune_ok = { + 'channel_max': self.connection_tunning.get('channel_max', self.server_channel_max), + 'frame_max': self.connection_tunning.get('frame_max', self.server_frame_max), + 'heartbeat': self.connection_tunning.get('heartbeat', self.server_heartbeat), + } + # "tune" the connexion with max channel, max frame, heartbeat + await self.tune_ok(**tune_ok) + + # update connection tunning values + self.server_frame_max = tune_ok['frame_max'] + self.server_channel_max = tune_ok['channel_max'] + self.server_heartbeat = tune_ok['heartbeat'] + + if self.server_heartbeat > 0: + self._heartbeat_timer_recv_reset() + self._heartbeat_timer_send_reset() + + # open a virtualhost + await self.open(virtualhost, capabilities='', insist=insist) + + # wait for open-ok + channel, frame = await self.get_frame() + await self.dispatch_frame(channel, frame) + + await self.ensure_open() + # for now, we read server's responses asynchronously + self.worker = asyncio.ensure_future(self.run(), loop=self._loop) + + async def get_frame(self): + """Read the frame, and only decode its header + + """ + return await 
amqp_frame.read(self._stream_reader) + + async def dispatch_frame(self, frame_channel=None, frame=None): + """Dispatch the received frame to the corresponding handler""" + + method_dispatch = { + pamqp.specification.Connection.Close.name: self.server_close, + pamqp.specification.Connection.CloseOk.name: self.close_ok, + pamqp.specification.Connection.Tune.name: self.tune, + pamqp.specification.Connection.Start.name: self.start, + pamqp.specification.Connection.OpenOk.name: self.open_ok, + } + if frame_channel is None and frame is None: + frame_channel, frame = await self.get_frame() + + if isinstance(frame, pamqp.heartbeat.Heartbeat): + return + + if frame_channel != 0: + channel = self.channels.get(frame_channel) + if channel is not None: + await channel.dispatch_frame(frame) + else: + logger.info("Unknown channel %s", frame_channel) + return + + if frame.name not in method_dispatch: + logger.info("frame %s is not handled", frame.name) + return + await method_dispatch[frame.name](frame) + + def release_channel_id(self, channel_id): + """Called from the channel instance, it relase a previously used + channel_id + """ + self.channels_ids_free.add(channel_id) + + @property + def channels_ids_count(self): + return self.channels_ids_ceil - len(self.channels_ids_free) + + def _close_channels(self, reply_code=None, reply_text=None, exception=None): + """Cleanly close channels + + Args: + reply_code: int, the amqp error code + reply_text: str, the text associated to the error_code + exc: the exception responsible of this error + + """ + if exception is None: + exception = exceptions.ChannelClosed(reply_code, reply_text) + + if self._on_error_callback: + if asyncio.iscoroutinefunction(self._on_error_callback): + asyncio.ensure_future(self._on_error_callback(exception), loop=self._loop) + else: + self._on_error_callback(exceptions.ChannelClosed(exception)) + + for channel in self.channels.values(): + channel.connection_closed(reply_code, reply_text, exception) + + async def run(self): + while not self.stop_now.done(): + try: + await self.dispatch_frame() + except exceptions.AmqpClosedConnection as exc: + logger.info("Close connection") + self.stop_now.set_result(None) + + self._close_channels(exception=exc) + except Exception: # pylint: disable=broad-except + logger.exception('error on dispatch') + + async def heartbeat(self): + """ deprecated heartbeat coroutine + + This coroutine is now a no-op as the heartbeat is handled directly by + the rest of the AmqpProtocol class. This is kept around for backwards + compatibility purposes only. + """ + await self.stop_now + + async def send_heartbeat(self): + """Sends an heartbeat message. + It can be an ack for the server or the client willing to check for the + connexion timeout + """ + request = pamqp.heartbeat.Heartbeat() + await self._write_frame(0, request) + + def _heartbeat_timer_recv_timeout(self): + # 4.2.7 If a peer detects no incoming traffic (i.e. received octets) for + # two heartbeat intervals or longer, it should close the connection + # without following the Connection.Close/Close-Ok handshaking, and log + # an error. 
+ # TODO(rcardona) raise a "timeout" exception somewhere + self._stream_writer.close() + + def _heartbeat_timer_recv_reset(self): + if self.server_heartbeat is None: + return + if self._heartbeat_timer_recv is not None: + self._heartbeat_timer_recv.cancel() + self._heartbeat_timer_recv = self._loop.call_later( + self.server_heartbeat * 2, + self._heartbeat_timer_recv_timeout) + + def _heartbeat_timer_send_reset(self): + if self.server_heartbeat is None: + return + if self._heartbeat_timer_send is not None: + self._heartbeat_timer_send.cancel() + self._heartbeat_timer_send = self._loop.call_later( + self.server_heartbeat, + self._heartbeat_trigger_send.set) + if self._heartbeat_worker is None: + self._heartbeat_worker = asyncio.ensure_future(self._heartbeat(), loop=self._loop) + + def _heartbeat_stop(self): + self.server_heartbeat = None + if self._heartbeat_timer_recv is not None: + self._heartbeat_timer_recv.cancel() + if self._heartbeat_timer_send is not None: + self._heartbeat_timer_send.cancel() + if self._heartbeat_worker is not None: + self._heartbeat_worker.cancel() + + async def _heartbeat(self): + while self.state != CLOSED: + await self._heartbeat_trigger_send.wait() + self._heartbeat_trigger_send.clear() + await self.send_heartbeat() + + # Amqp specific methods + async def start(self, frame): + """Method sent from the server to begin a new connection""" + self.version_major = frame.version_major + self.version_minor = frame.version_minor + self.server_properties = frame.server_properties + self.server_mechanisms = frame.mechanisms + self.server_locales = frame.locales + + async def start_ok(self, client_properties, mechanism, auth, locale): + def credentials(): + return '\0{LOGIN}\0{PASSWORD}'.format(**auth) + + request = pamqp.specification.Connection.StartOk( + client_properties=client_properties, + mechanism=mechanism, + locale=locale, + response=credentials() + ) + await self._write_frame(0, request) + + async def server_close(self, frame): + """The server is closing the connection""" + self.state = CLOSING + reply_code = frame.reply_code + reply_text = frame.reply_text + class_id = frame.class_id + method_id = frame.method_id + logger.warning("Server closed connection: %s, code=%s, class_id=%s, method_id=%s", + reply_text, reply_code, class_id, method_id) + self._close_channels(reply_code, reply_text) + await self._close_ok() + self._stream_writer.close() + + async def _close_ok(self): + request = pamqp.specification.Connection.CloseOk() + await self._write_frame(0, request) + + async def tune(self, frame): + self.server_channel_max = frame.channel_max + self.server_frame_max = frame.frame_max + self.server_heartbeat = frame.heartbeat + + async def tune_ok(self, channel_max, frame_max, heartbeat): + request = pamqp.specification.Connection.TuneOk( + channel_max, frame_max, heartbeat + ) + await self._write_frame(0, request) + + async def secure_ok(self, login_response): + pass + + async def open(self, virtual_host, capabilities='', insist=False): + """Open connection to virtual host.""" + request = pamqp.specification.Connection.Open( + virtual_host, capabilities, insist + ) + await self._write_frame(0, request) + + async def open_ok(self, frame): + self.state = OPEN + logger.info("Recv open ok") + + # + ## aioamqp public methods + # + + async def channel(self, **kwargs): + """Factory to create a new channel + + """ + await self.ensure_open() + try: + channel_id = self.channels_ids_free.pop() + except KeyError as ex: + assert self.server_channel_max is not None, 'connection 
channel-max tuning not performed' + # channel-max = 0 means no limit + if self.server_channel_max and self.channels_ids_ceil > self.server_channel_max: + raise exceptions.NoChannelAvailable() from ex + self.channels_ids_ceil += 1 + channel_id = self.channels_ids_ceil + channel = self.CHANNEL_FACTORY(self, channel_id, **kwargs) + self.channels[channel_id] = channel + await channel.open() + return channel diff --git a/aioamqp/tests/test_basic.py b/aioamqp/tests/test_basic.py new file mode 100644 index 0000000..00ec880 --- /dev/null +++ b/aioamqp/tests/test_basic.py @@ -0,0 +1,217 @@ +""" + Amqp basic class tests +""" + +import asyncio +import asynctest + +from . import testcase +from .. import exceptions +from .. import properties + + +class QosTestCase(testcase.RabbitTestCaseMixin, asynctest.TestCase): + + async def test_basic_qos_default_args(self): + result = await self.channel.basic_qos() + self.assertTrue(result) + + async def test_basic_qos(self): + result = await self.channel.basic_qos( + prefetch_size=0, + prefetch_count=100, + connection_global=False) + + self.assertTrue(result) + + async def test_basic_qos_prefetch_size(self): + with self.assertRaises(exceptions.ChannelClosed) as cm: + await self.channel.basic_qos( + prefetch_size=10, + prefetch_count=100, + connection_global=False) + + self.assertEqual(cm.exception.code, 540) + + async def test_basic_qos_wrong_values(self): + with self.assertRaises(TypeError): + await self.channel.basic_qos( + prefetch_size=100000, + prefetch_count=1000000000, + connection_global=False) + + +class BasicCancelTestCase(testcase.RabbitTestCaseMixin, asynctest.TestCase): + + async def test_basic_cancel(self): + + async def callback(channel, body, envelope, _properties): + pass + + queue_name = 'queue_name' + exchange_name = 'exchange_name' + await self.channel.queue_declare(queue_name) + await self.channel.exchange_declare(exchange_name, type_name='direct') + await self.channel.queue_bind(queue_name, exchange_name, routing_key='') + result = await self.channel.basic_consume(callback, queue_name=queue_name) + result = await self.channel.basic_cancel(result['consumer_tag']) + + result = await self.channel.publish("payload", exchange_name, routing_key='') + + await asyncio.sleep(5) + + result = await self.channel.queue_declare(queue_name, passive=True) + self.assertEqual(result['message_count'], 1) + self.assertEqual(result['consumer_count'], 0) + + + async def test_basic_cancel_unknown_ctag(self): + result = await self.channel.basic_cancel("unknown_ctag") + self.assertTrue(result) + + +class BasicGetTestCase(testcase.RabbitTestCaseMixin, asynctest.TestCase): + + + async def test_basic_get(self): + queue_name = 'queue_name' + exchange_name = 'exchange_name' + routing_key = '' + + await self.channel.queue_declare(queue_name) + await self.channel.exchange_declare(exchange_name, type_name='direct') + await self.channel.queue_bind(queue_name, exchange_name, routing_key=routing_key) + + await self.channel.publish("payload", exchange_name, routing_key=routing_key) + + result = await self.channel.basic_get(queue_name) + self.assertEqual(result['routing_key'], routing_key) + self.assertFalse(result['redelivered']) + self.assertIn('delivery_tag', result) + self.assertEqual(result['exchange_name'].split('.')[-1], exchange_name) + self.assertEqual(result['message'], b'payload') + self.assertIsInstance(result['properties'], properties.Properties) + + async def test_basic_get_empty(self): + queue_name = 'queue_name' + exchange_name = 'exchange_name' + routing_key = 
'' + await self.channel.queue_declare(queue_name) + await self.channel.exchange_declare(exchange_name, type_name='direct') + await self.channel.queue_bind(queue_name, exchange_name, routing_key=routing_key) + + with self.assertRaises(exceptions.EmptyQueue): + await self.channel.basic_get(queue_name) + + +class BasicDeliveryTestCase(testcase.RabbitTestCaseMixin, asynctest.TestCase): + + + async def publish(self, queue_name, exchange_name, routing_key, payload): + await self.channel.queue_declare(queue_name, exclusive=False, no_wait=False) + await self.channel.exchange_declare(exchange_name, type_name='fanout') + await self.channel.queue_bind(queue_name, exchange_name, routing_key=routing_key) + await self.channel.publish(payload, exchange_name, queue_name) + + + + async def test_ack_message(self): + queue_name = 'queue_name' + exchange_name = 'exchange_name' + routing_key = '' + + await self.publish( + queue_name, exchange_name, routing_key, "payload" + ) + + qfuture = asyncio.Future(loop=self.loop) + + async def qcallback(channel, body, envelope, _properties): + qfuture.set_result(envelope) + + await self.channel.basic_consume(qcallback, queue_name=queue_name) + envelope = await qfuture + + await qfuture + await self.channel.basic_client_ack(envelope.delivery_tag) + + async def test_basic_nack(self): + queue_name = 'queue_name' + exchange_name = 'exchange_name' + routing_key = '' + + await self.publish( + queue_name, exchange_name, routing_key, "payload" + ) + + qfuture = asyncio.Future(loop=self.loop) + + async def qcallback(channel, body, envelope, _properties): + await self.channel.basic_client_nack( + envelope.delivery_tag, multiple=True, requeue=False + ) + qfuture.set_result(True) + + await self.channel.basic_consume(qcallback, queue_name=queue_name) + await qfuture + + async def test_basic_nack_norequeue(self): + queue_name = 'queue_name' + exchange_name = 'exchange_name' + routing_key = '' + + await self.publish( + queue_name, exchange_name, routing_key, "payload" + ) + + qfuture = asyncio.Future(loop=self.loop) + + async def qcallback(channel, body, envelope, _properties): + await self.channel.basic_client_nack(envelope.delivery_tag, requeue=False) + qfuture.set_result(True) + + await self.channel.basic_consume(qcallback, queue_name=queue_name) + await qfuture + + async def test_basic_nack_requeue(self): + queue_name = 'queue_name' + exchange_name = 'exchange_name' + routing_key = '' + + await self.publish( + queue_name, exchange_name, routing_key, "payload" + ) + + qfuture = asyncio.Future(loop=self.loop) + called = False + + async def qcallback(channel, body, envelope, _properties): + nonlocal called + if not called: + called = True + await self.channel.basic_client_nack(envelope.delivery_tag, requeue=True) + else: + await self.channel.basic_client_ack(envelope.delivery_tag) + qfuture.set_result(True) + + await self.channel.basic_consume(qcallback, queue_name=queue_name) + await qfuture + + + async def test_basic_reject(self): + queue_name = 'queue_name' + exchange_name = 'exchange_name' + routing_key = '' + await self.publish( + queue_name, exchange_name, routing_key, "payload" + ) + + qfuture = asyncio.Future(loop=self.loop) + + async def qcallback(channel, body, envelope, _properties): + qfuture.set_result(envelope) + + await self.channel.basic_consume(qcallback, queue_name=queue_name) + envelope = await qfuture + + await self.channel.basic_reject(envelope.delivery_tag) diff --git a/aioamqp/tests/test_connection_lost.py b/aioamqp/tests/test_connection_lost.py new file mode 
100644 index 0000000..d16d51b --- /dev/null +++ b/aioamqp/tests/test_connection_lost.py @@ -0,0 +1,30 @@ +import asynctest +import asynctest.mock +import asyncio + +from aioamqp.protocol import OPEN, CLOSED + +from . import testcase + + +class ConnectionLostTestCase(testcase.RabbitTestCaseMixin, asynctest.TestCase): + + _multiprocess_can_split_ = True + + async def test_connection_lost(self): + + self.callback_called = False + + def callback(*args, **kwargs): + self.callback_called = True + + amqp = self.amqp + amqp._on_error_callback = callback + channel = self.channel + self.assertEqual(amqp.state, OPEN) + self.assertTrue(channel.is_open) + amqp._stream_reader._transport.close() # this should have the same effect as the tcp connection being lost + await asyncio.wait_for(amqp.worker, 1) + self.assertEqual(amqp.state, CLOSED) + self.assertFalse(channel.is_open) + self.assertTrue(self.callback_called) diff --git a/aioamqp/tests/test_protocol.py b/aioamqp/tests/test_protocol.py new file mode 100644 index 0000000..80d4187 --- /dev/null +++ b/aioamqp/tests/test_protocol.py @@ -0,0 +1,89 @@ +""" + Test our Protocol class +""" +import asynctest +from unittest import mock + +from . import testcase +from .. import exceptions +from .. import connect as amqp_connect +from .. import from_url as amqp_from_url +from ..protocol import AmqpProtocol, OPEN + + +class ProtocolTestCase(testcase.RabbitTestCaseMixin, asynctest.TestCase): + + async def test_connect(self): + _transport, protocol = await amqp_connect( + host=self.host, port=self.port, virtualhost=self.vhost, loop=self.loop + ) + self.assertEqual(protocol.state, OPEN) + await protocol.close() + + async def test_connect_products_info(self): + client_properties = { + 'program': 'aioamqp-tests', + 'program_version': '0.1.1', + } + _transport, protocol = await amqp_connect( + host=self.host, + port=self.port, + virtualhost=self.vhost, + client_properties=client_properties, + loop=self.loop, + ) + + self.assertEqual(protocol.client_properties, client_properties) + await protocol.close() + + async def test_connection_unexistant_vhost(self): + with self.assertRaises(exceptions.AmqpClosedConnection): + await amqp_connect(host=self.host, port=self.port, virtualhost='/unexistant', loop=self.loop) + + def test_connection_wrong_login_password(self): + with self.assertRaises(exceptions.AmqpClosedConnection): + self.loop.run_until_complete( + amqp_connect(host=self.host, port=self.port, login='wrong', password='wrong', loop=self.loop) + ) + + async def test_connection_from_url(self): + with mock.patch('aioamqp.connect') as connect: + async def func(*x, **y): + return 1, 2 + connect.side_effect = func + await amqp_from_url('amqp://tom:pass@example.com:7777/myvhost', loop=self.loop) + connect.assert_called_once_with( + insist=False, + password='pass', + login_method='PLAIN', + login='tom', + host='example.com', + protocol_factory=AmqpProtocol, + virtualhost='myvhost', + port=7777, + loop=self.loop, + ) + + async def test_ssl_context_connection_from_url(self): + ssl_context = mock.Mock() + with mock.patch('aioamqp.connect') as connect: + async def func(*x, **y): + return 1, 2 + connect.side_effect = func + await amqp_from_url('amqps://tom:pass@example.com:7777/myvhost', loop=self.loop, ssl=ssl_context) + connect.assert_called_once_with( + insist=False, + password='pass', + login_method='PLAIN', + ssl=ssl_context, + login='tom', + host='example.com', + protocol_factory=AmqpProtocol, + virtualhost='myvhost', + port=7777, + loop=self.loop, + ) + + async def 
test_from_url_raises_on_wrong_scheme(self): + with self.assertRaises(ValueError): + await amqp_from_url('invalid://') diff --git a/async_amqp/protocol.py b/async_amqp/protocol.py index 32448ba..cbe433d 100644 --- a/async_amqp/protocol.py +++ b/async_amqp/protocol.py @@ -607,13 +607,13 @@ async def channel(self, **kwargs): await self.ensure_open() try: channel_id = self.channels_ids_free.pop() - except KeyError: + except KeyError as ex: assert self.server_channel_max is not None, \ 'connection channel-max tuning not performed' # channel-max = 0 means no limit if self.server_channel_max and \ self.channels_ids_ceil > self.server_channel_max: - raise exceptions.NoChannelAvailable() + raise exceptions.NoChannelAvailable() from ex self.channels_ids_ceil += 1 channel_id = self.channels_ids_ceil channel = self.CHANNEL_FACTORY(self, channel_id, **kwargs) diff --git a/docs/changelog.rst b/docs/changelog.rst index 0b43eda..08edf82 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -25,6 +25,10 @@ Trio-amqp 0.1 Next release ------------ + * Add support for Python 3.9. + * Drop support for Python 3.5. + * Fix annoying auth method warning caused by a wrongly defined default argument (closes #214). + Aioamqp 0.14.0 -------------- diff --git a/docs/introduction.rst b/docs/introduction.rst index 3a18a66..af9541d 100644 --- a/docs/introduction.rst +++ b/docs/introduction.rst @@ -8,7 +8,7 @@ AsyncAmqp library is a pure-Python implementation of the AMQP 0.9.1 protocol usi Prerequisites ------------- -AsyncAmqp works only with python >= 3.6 using asyncio, trio or curio. +AsyncAmqp works with python >= 3.7 using asyncio, trio or curio. Installation ------------ diff --git a/setup.cfg b/setup.cfg index 0ab7d0b..9a9f5ab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,2 @@ [bdist_wheel] -python-tag = py35.py36.py37.py38 +python-tag = py36.py37.py38.py39 diff --git a/tox.ini b/tox.ini index 1950753..8393578 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py36, py37, py38 +envlist = py37, py38, py39 skipsdist = true skip_missing_interpreters = true
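
A minimal usage sketch of the client API added in this diff (connect, channel, declare/bind, consume, publish). It assumes a local RabbitMQ broker reachable with the default guest/guest credentials; the queue, exchange and routing-key names are illustrative only and are not part of the change set.

import asyncio

import aioamqp


async def main():
    # connect() returns (transport, protocol); protocol is the AmqpProtocol instance
    transport, protocol = await aioamqp.connect(
        host='localhost', port=5672, login='guest', password='guest', virtualhost='/')
    channel = await protocol.channel()

    # Declare a queue and a direct exchange, then bind them together.
    await channel.queue_declare(queue_name='demo.queue', durable=True)
    await channel.exchange_declare(exchange_name='demo.exchange', type_name='direct')
    await channel.queue_bind('demo.queue', 'demo.exchange', routing_key='demo')

    async def on_message(channel, body, envelope, properties):
        # Consumer callback: acknowledge each delivery explicitly.
        print('received:', body)
        await channel.basic_client_ack(envelope.delivery_tag)

    await channel.basic_consume(on_message, queue_name='demo.queue')

    # Publish a bytes payload (str payloads are deprecated, see the warning in basic_publish).
    await channel.basic_publish(b'hello', 'demo.exchange', routing_key='demo')

    await asyncio.sleep(1)  # give the consumer callback time to run
    await protocol.close()
    transport.close()


asyncio.run(main())

The explicit connect() call could equally be replaced by aioamqp.from_url('amqp://guest:guest@localhost:5672/'), which parses the credentials, host, port and vhost from a single URL.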