diff --git a/.gitignore b/.gitignore index 1502f8a..5d5f15f 100644 --- a/.gitignore +++ b/.gitignore @@ -58,3 +58,6 @@ docs/_build/ # editor stuffs *.swp /.pybuild + +# Pycharm/IntelliJ +.idea/* diff --git a/.pylintrc b/.pylintrc index d5b1838..7c6d5da 100644 --- a/.pylintrc +++ b/.pylintrc @@ -47,7 +47,8 @@ disable= too-many-locals, too-many-public-methods, too-many-statements, - unused-argument + unused-argument, + wrong-import-order [REPORTS] diff --git a/.travis.yml b/.travis.yml index f129d1f..07a7c9a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,10 @@ language: python +dist: bionic python: - 3.5 - 3.6 - 3.7 +- 3.8 services: - rabbitmq install: @@ -21,4 +23,3 @@ before_script: - ./rabbitmqadmin declare permission vhost=test user=guest read=".*" write=".*" configure=".*" script: - make test -- coverage diff --git a/AUTHORS.rst b/AUTHORS.rst index f3c9a69..2611cdf 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -1,4 +1,4 @@ -asyncamqp was originally created in early 2014 at Polyconseil. +asyncamqp was originally created as anyamqp in early 2014 at Polyconseil. AUTHORS are (and/or have been):: @@ -19,4 +19,10 @@ AUTHORS are (and/or have been):: * Alexander Gromyko * Nick Humrich * Pavel Kamaev - + * Mads Sejersen + * Dave Shawley + * Jacob Hagstedt P Suorra + * Matthias Urlichs + * Corey `notmeta` + * Paul Wistrand + * fullylegit diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..7ec4545 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.5 + +WORKDIR /usr/src/app + +COPY . . + +RUN pip install -r requirements_dev.txt diff --git a/Makefile b/Makefile index e05970f..725d85c 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,9 @@ test: update: pip install -r ci/requirements_dev.txt +pylint: + pylint aioamqp + ### semi-private targets used by polyconseil's CI (copy-pasted from blease) ### @@ -47,9 +50,9 @@ reports: mkdir -p reports jenkins-test: reports - $(MAKE) test TEST_OPTIONS="--with-coverage --cover-package=$(PACKAGE) \ - --cover-xml --cover-xml-file=reports/xmlcov.xml \ - --with-xunit --xunit-file=reports/TEST-$(PACKAGE).xml \ + $(MAKE) test TEST_OPTIONS="--cov=$(PACKAGE) \ + --cov-report xml:reports/xmlcov.xml \ + --junitxml=reports/TEST-$(PACKAGE).xml \ -v \ $(TEST_OPTIONS)" diff --git a/README.rst b/README.rst index 7f7cfa5..c7caa7f 100644 --- a/README.rst +++ b/README.rst @@ -46,6 +46,11 @@ Tests require an instance of RabbitMQ. You can start a new instance using docker Then you can run the tests with ``make test`` (requires ``pytest``). +tests using docker-compose +^^^^^^^^^^^^^^^^^^^^^^^^^^ +Start RabbitMQ using ``docker-compose up -d rabbitmq``. When RabbitMQ has started, start the tests using ``docker-compose up --build aioamqp-test`` + + Future work ----------- diff --git a/aioamqp/protocol.py b/aioamqp/protocol.py new file mode 100644 index 0000000..ddf1285 --- /dev/null +++ b/aioamqp/protocol.py @@ -0,0 +1,469 @@ +""" + Amqp Protocol +""" + +import asyncio +import logging + +import pamqp.frame +import pamqp.heartbeat +import pamqp.specification + +from . import channel as amqp_channel +from . import constants as amqp_constants +from . import frame as amqp_frame +from . import exceptions +from . import version + + +logger = logging.getLogger(__name__) + + +CONNECTING, OPEN, CLOSING, CLOSED = range(4) + + +class _StreamWriter(asyncio.StreamWriter): + + def write(self, data): + super().write(data) + self._protocol._heartbeat_timer_send_reset() + + def writelines(self, data): + super().writelines(data) + self._protocol._heartbeat_timer_send_reset() + + def write_eof(self): + ret = super().write_eof() + self._protocol._heartbeat_timer_send_reset() + return ret + + +class AmqpProtocol(asyncio.StreamReaderProtocol): + """The AMQP protocol for asyncio. + + See http://docs.python.org/3.4/library/asyncio-protocol.html#protocols for more information + on asyncio's protocol API. + + """ + + CHANNEL_FACTORY = amqp_channel.Channel + + def __init__(self, *args, **kwargs): + """Defines our new protocol instance + + Args: + channel_max: int, specifies highest channel number that the server permits. + Usable channel numbers are in the range 1..channel-max. + Zero indicates no specified limit. + frame_max: int, the largest frame size that the server proposes for the connection, + including frame header and end-byte. The client can negotiate a lower value. + Zero means that the server does not impose any specific limit + but may reject very large frames if it cannot allocate resources for them. + heartbeat: int, the delay, in seconds, of the connection heartbeat that the server wants. + Zero means the server does not want a heartbeat. + loop: Asyncio.Eventloop: specify the eventloop to use. + client_properties: dict, client-props to tune the client identification + """ + self._loop = kwargs.get('loop') or asyncio.get_event_loop() + self._reader = asyncio.StreamReader(loop=self._loop) + super().__init__(self._reader, loop=self._loop) + self._on_error_callback = kwargs.get('on_error') + + self.client_properties = kwargs.get('client_properties', {}) + self.connection_tunning = {} + if 'channel_max' in kwargs: + self.connection_tunning['channel_max'] = kwargs.get('channel_max') + if 'frame_max' in kwargs: + self.connection_tunning['frame_max'] = kwargs.get('frame_max') + if 'heartbeat' in kwargs: + self.connection_tunning['heartbeat'] = kwargs.get('heartbeat') + + self.connecting = asyncio.Future(loop=self._loop) + self.connection_closed = asyncio.Event(loop=self._loop) + self.stop_now = asyncio.Future(loop=self._loop) + self.state = CONNECTING + self.version_major = None + self.version_minor = None + self.server_properties = None + self.server_mechanisms = None + self.server_locales = None + self.worker = None + self.server_heartbeat = None + self._heartbeat_timer_recv = None + self._heartbeat_timer_send = None + self._heartbeat_trigger_send = asyncio.Event(loop=self._loop) + self._heartbeat_worker = None + self.channels = {} + self.server_frame_max = None + self.server_channel_max = None + self.channels_ids_ceil = 0 + self.channels_ids_free = set() + self._drain_lock = asyncio.Lock(loop=self._loop) + + def connection_made(self, transport): + super().connection_made(transport) + self._stream_writer = _StreamWriter(transport, self, self._stream_reader, self._loop) + + def eof_received(self): + super().eof_received() + # Python 3.5+ started returning True here to keep the transport open. + # We really couldn't care less so keep the behavior from 3.4 to make + # sure connection_lost() is called. + return False + + def connection_lost(self, exc): + if exc is not None: + logger.warning("Connection lost exc=%r", exc) + self.connection_closed.set() + self.state = CLOSED + self._close_channels(exception=exc) + self._heartbeat_stop() + super().connection_lost(exc) + + def data_received(self, data): + self._heartbeat_timer_recv_reset() + super().data_received(data) + + async def ensure_open(self): + # Raise a suitable exception if the connection isn't open. + # Handle cases from the most common to the least common. + + if self.state == OPEN: + return + + if self.state == CLOSED: + raise exceptions.AmqpClosedConnection() + + # If the closing handshake is in progress, let it complete. + if self.state == CLOSING: + await self.wait_closed() + raise exceptions.AmqpClosedConnection() + + # Control may only reach this point in buggy third-party subclasses. + assert self.state == CONNECTING + raise exceptions.AioamqpException("connection isn't established yet.") + + async def _drain(self): + async with self._drain_lock: + # drain() cannot be called concurrently by multiple coroutines: + # http://bugs.python.org/issue29930. Remove this lock when no + # version of Python where this bugs exists is supported anymore. + await self._stream_writer.drain() + + async def _write_frame(self, channel_id, request, drain=True): + amqp_frame.write(self._stream_writer, channel_id, request) + if drain: + await self._drain() + + async def close(self, no_wait=False, timeout=None): + """Close connection (and all channels)""" + await self.ensure_open() + self.state = CLOSING + request = pamqp.specification.Connection.Close( + reply_code=0, + reply_text='', + class_id=0, + method_id=0 + ) + + await self._write_frame(0, request) + if not no_wait: + await self.wait_closed(timeout=timeout) + + async def wait_closed(self, timeout=None): + await asyncio.wait_for(self.connection_closed.wait(), timeout=timeout, loop=self._loop) + if self._heartbeat_worker is not None: + try: + await asyncio.wait_for(self._heartbeat_worker, timeout=timeout, loop=self._loop) + except asyncio.CancelledError: + pass + + async def close_ok(self, frame): + self._stream_writer.close() + + async def start_connection(self, host, port, login, password, virtualhost, ssl=False, + login_method='PLAIN', insist=False): + """Initiate a connection at the protocol level + We send `PROTOCOL_HEADER' + """ + + if login_method != 'PLAIN': + logger.warning('only PLAIN login_method is supported, falling back to AMQPLAIN') + + self._stream_writer.write(amqp_constants.PROTOCOL_HEADER) + + # Wait 'start' method from the server + await self.dispatch_frame() + + client_properties = { + 'capabilities': { + 'consumer_cancel_notify': True, + 'connection.blocked': False, + }, + 'copyright': 'BSD', + 'product': version.__package__, + 'product_version': version.__version__, + } + client_properties.update(self.client_properties) + auth = { + 'LOGIN': login, + 'PASSWORD': password, + } + + # waiting reply start with credentions and co + await self.start_ok(client_properties, 'PLAIN', auth, self.server_locales) + + # wait for a "tune" reponse + await self.dispatch_frame() + + tune_ok = { + 'channel_max': self.connection_tunning.get('channel_max', self.server_channel_max), + 'frame_max': self.connection_tunning.get('frame_max', self.server_frame_max), + 'heartbeat': self.connection_tunning.get('heartbeat', self.server_heartbeat), + } + # "tune" the connexion with max channel, max frame, heartbeat + await self.tune_ok(**tune_ok) + + # update connection tunning values + self.server_frame_max = tune_ok['frame_max'] + self.server_channel_max = tune_ok['channel_max'] + self.server_heartbeat = tune_ok['heartbeat'] + + if self.server_heartbeat > 0: + self._heartbeat_timer_recv_reset() + self._heartbeat_timer_send_reset() + + # open a virtualhost + await self.open(virtualhost, capabilities='', insist=insist) + + # wait for open-ok + channel, frame = await self.get_frame() + await self.dispatch_frame(channel, frame) + + await self.ensure_open() + # for now, we read server's responses asynchronously + self.worker = asyncio.ensure_future(self.run(), loop=self._loop) + + async def get_frame(self): + """Read the frame, and only decode its header + + """ + return await amqp_frame.read(self._stream_reader) + + async def dispatch_frame(self, frame_channel=None, frame=None): + """Dispatch the received frame to the corresponding handler""" + + method_dispatch = { + pamqp.specification.Connection.Close.name: self.server_close, + pamqp.specification.Connection.CloseOk.name: self.close_ok, + pamqp.specification.Connection.Tune.name: self.tune, + pamqp.specification.Connection.Start.name: self.start, + pamqp.specification.Connection.OpenOk.name: self.open_ok, + } + if frame_channel is None and frame is None: + frame_channel, frame = await self.get_frame() + + if isinstance(frame, pamqp.heartbeat.Heartbeat): + return + + if frame_channel != 0: + channel = self.channels.get(frame_channel) + if channel is not None: + await channel.dispatch_frame(frame) + else: + logger.info("Unknown channel %s", frame_channel) + return + + if frame.name not in method_dispatch: + logger.info("frame %s is not handled", frame.name) + return + await method_dispatch[frame.name](frame) + + def release_channel_id(self, channel_id): + """Called from the channel instance, it relase a previously used + channel_id + """ + self.channels_ids_free.add(channel_id) + + @property + def channels_ids_count(self): + return self.channels_ids_ceil - len(self.channels_ids_free) + + def _close_channels(self, reply_code=None, reply_text=None, exception=None): + """Cleanly close channels + + Args: + reply_code: int, the amqp error code + reply_text: str, the text associated to the error_code + exc: the exception responsible of this error + + """ + if exception is None: + exception = exceptions.ChannelClosed(reply_code, reply_text) + + if self._on_error_callback: + if asyncio.iscoroutinefunction(self._on_error_callback): + asyncio.ensure_future(self._on_error_callback(exception), loop=self._loop) + else: + self._on_error_callback(exceptions.ChannelClosed(exception)) + + for channel in self.channels.values(): + channel.connection_closed(reply_code, reply_text, exception) + + async def run(self): + while not self.stop_now.done(): + try: + await self.dispatch_frame() + except exceptions.AmqpClosedConnection as exc: + logger.info("Close connection") + self.stop_now.set_result(None) + + self._close_channels(exception=exc) + except Exception: # pylint: disable=broad-except + logger.exception('error on dispatch') + + async def heartbeat(self): + """ deprecated heartbeat coroutine + + This coroutine is now a no-op as the heartbeat is handled directly by + the rest of the AmqpProtocol class. This is kept around for backwards + compatibility purposes only. + """ + await self.stop_now + + async def send_heartbeat(self): + """Sends an heartbeat message. + It can be an ack for the server or the client willing to check for the + connexion timeout + """ + request = pamqp.heartbeat.Heartbeat() + await self._write_frame(0, request) + + def _heartbeat_timer_recv_timeout(self): + # 4.2.7 If a peer detects no incoming traffic (i.e. received octets) for + # two heartbeat intervals or longer, it should close the connection + # without following the Connection.Close/Close-Ok handshaking, and log + # an error. + # TODO(rcardona) raise a "timeout" exception somewhere + self._stream_writer.close() + + def _heartbeat_timer_recv_reset(self): + if self.server_heartbeat is None: + return + if self._heartbeat_timer_recv is not None: + self._heartbeat_timer_recv.cancel() + self._heartbeat_timer_recv = self._loop.call_later( + self.server_heartbeat * 2, + self._heartbeat_timer_recv_timeout) + + def _heartbeat_timer_send_reset(self): + if self.server_heartbeat is None: + return + if self._heartbeat_timer_send is not None: + self._heartbeat_timer_send.cancel() + self._heartbeat_timer_send = self._loop.call_later( + self.server_heartbeat, + self._heartbeat_trigger_send.set) + if self._heartbeat_worker is None: + self._heartbeat_worker = asyncio.ensure_future(self._heartbeat(), loop=self._loop) + + def _heartbeat_stop(self): + self.server_heartbeat = None + if self._heartbeat_timer_recv is not None: + self._heartbeat_timer_recv.cancel() + if self._heartbeat_timer_send is not None: + self._heartbeat_timer_send.cancel() + if self._heartbeat_worker is not None: + self._heartbeat_worker.cancel() + + async def _heartbeat(self): + while self.state != CLOSED: + await self._heartbeat_trigger_send.wait() + self._heartbeat_trigger_send.clear() + await self.send_heartbeat() + + # Amqp specific methods + async def start(self, frame): + """Method sent from the server to begin a new connection""" + self.version_major = frame.version_major + self.version_minor = frame.version_minor + self.server_properties = frame.server_properties + self.server_mechanisms = frame.mechanisms + self.server_locales = frame.locales + + async def start_ok(self, client_properties, mechanism, auth, locale): + def credentials(): + return '\0{LOGIN}\0{PASSWORD}'.format(**auth) + + request = pamqp.specification.Connection.StartOk( + client_properties=client_properties, + mechanism=mechanism, + locale=locale, + response=credentials() + ) + await self._write_frame(0, request) + + async def server_close(self, frame): + """The server is closing the connection""" + self.state = CLOSING + reply_code = frame.reply_code + reply_text = frame.reply_text + class_id = frame.class_id + method_id = frame.method_id + logger.warning("Server closed connection: %s, code=%s, class_id=%s, method_id=%s", + reply_text, reply_code, class_id, method_id) + self._close_channels(reply_code, reply_text) + await self._close_ok() + self._stream_writer.close() + + async def _close_ok(self): + request = pamqp.specification.Connection.CloseOk() + await self._write_frame(0, request) + + async def tune(self, frame): + self.server_channel_max = frame.channel_max + self.server_frame_max = frame.frame_max + self.server_heartbeat = frame.heartbeat + + async def tune_ok(self, channel_max, frame_max, heartbeat): + request = pamqp.specification.Connection.TuneOk( + channel_max, frame_max, heartbeat + ) + await self._write_frame(0, request) + + async def secure_ok(self, login_response): + pass + + async def open(self, virtual_host, capabilities='', insist=False): + """Open connection to virtual host.""" + request = pamqp.specification.Connection.Open( + virtual_host, capabilities, insist + ) + await self._write_frame(0, request) + + async def open_ok(self, frame): + self.state = OPEN + logger.info("Recv open ok") + + # + ## aioamqp public methods + # + + async def channel(self, **kwargs): + """Factory to create a new channel + + """ + await self.ensure_open() + try: + channel_id = self.channels_ids_free.pop() + except KeyError: + assert self.server_channel_max is not None, 'connection channel-max tuning not performed' + # channel-max = 0 means no limit + if self.server_channel_max and self.channels_ids_ceil > self.server_channel_max: + raise exceptions.NoChannelAvailable() + self.channels_ids_ceil += 1 + channel_id = self.channels_ids_ceil + channel = self.CHANNEL_FACTORY(self, channel_id, **kwargs) + self.channels[channel_id] = channel + await channel.open() + return channel diff --git a/asyncamqp/__init__.py b/asyncamqp/__init__.py index 6a95bd1..1e84f5a 100644 --- a/asyncamqp/__init__.py +++ b/asyncamqp/__init__.py @@ -5,9 +5,6 @@ from .exceptions import * # pylint: disable=wildcard-import # noqa: F401,F403 from .protocol import AmqpProtocol # noqa: F401 -from ._version import __version__ # noqa: F401 -from ._version import __packagename__ # noqa: F401 - from . import protocol connect_amqp = protocol.connect_amqp diff --git a/asyncamqp/channel.py b/asyncamqp/channel.py index 85c027d..97880b3 100644 --- a/asyncamqp/channel.py +++ b/asyncamqp/channel.py @@ -13,6 +13,7 @@ from . import constants as amqp_constants from . import frame as amqp_frame from . import exceptions +from . import properties as amqp_properties from .envelope import Envelope, ReturnEnvelope from .future import Future from .exceptions import AmqpClosedConnection @@ -157,62 +158,40 @@ async def connection_closed(self, server_code=None, server_reason=None, exceptio async def dispatch_frame(self, frame): methods = { - (amqp_constants.CLASS_CHANNEL, amqp_constants.CHANNEL_OPEN_OK): - self.open_ok, - (amqp_constants.CLASS_CHANNEL, amqp_constants.CHANNEL_FLOW_OK): - self.flow_ok, - (amqp_constants.CLASS_CHANNEL, amqp_constants.CHANNEL_CLOSE_OK): - self.close_ok, - (amqp_constants.CLASS_CHANNEL, amqp_constants.CHANNEL_CLOSE): - self.server_channel_close, - (amqp_constants.CLASS_EXCHANGE, amqp_constants.EXCHANGE_DECLARE_OK): - self.exchange_declare_ok, - (amqp_constants.CLASS_EXCHANGE, amqp_constants.EXCHANGE_BIND_OK): - self.exchange_bind_ok, - (amqp_constants.CLASS_EXCHANGE, amqp_constants.EXCHANGE_UNBIND_OK): - self.exchange_unbind_ok, - (amqp_constants.CLASS_EXCHANGE, amqp_constants.EXCHANGE_DELETE_OK): - self.exchange_delete_ok, - (amqp_constants.CLASS_QUEUE, amqp_constants.QUEUE_DECLARE_OK): - self.queue_declare_ok, - (amqp_constants.CLASS_QUEUE, amqp_constants.QUEUE_DELETE_OK): - self.queue_delete_ok, - (amqp_constants.CLASS_QUEUE, amqp_constants.QUEUE_BIND_OK): - self.queue_bind_ok, - (amqp_constants.CLASS_QUEUE, amqp_constants.QUEUE_UNBIND_OK): - self.queue_unbind_ok, - (amqp_constants.CLASS_QUEUE, amqp_constants.QUEUE_PURGE_OK): - self.queue_purge_ok, - (amqp_constants.CLASS_BASIC, amqp_constants.BASIC_QOS_OK): - self.basic_qos_ok, - (amqp_constants.CLASS_BASIC, amqp_constants.BASIC_CONSUME_OK): - self.basic_consume_ok, - (amqp_constants.CLASS_BASIC, amqp_constants.BASIC_CANCEL_OK): - self.basic_cancel_ok, - (amqp_constants.CLASS_BASIC, amqp_constants.BASIC_GET_OK): - self.basic_get_ok, - (amqp_constants.CLASS_BASIC, amqp_constants.BASIC_GET_EMPTY): - self.basic_get_empty, - (amqp_constants.CLASS_BASIC, amqp_constants.BASIC_DELIVER): - self.basic_deliver, - (amqp_constants.CLASS_BASIC, amqp_constants.BASIC_CANCEL): - self.server_basic_cancel, - (amqp_constants.CLASS_BASIC, amqp_constants.BASIC_ACK): - self.basic_server_ack, - (amqp_constants.CLASS_BASIC, amqp_constants.BASIC_NACK): - self.basic_server_nack, - (amqp_constants.CLASS_BASIC, amqp_constants.BASIC_RECOVER_OK): - self.basic_recover_ok, - (amqp_constants.CLASS_BASIC, amqp_constants.BASIC_RETURN): - self.basic_return, - (amqp_constants.CLASS_CONFIRM, amqp_constants.CONFIRM_SELECT_OK): - self.confirm_select_ok, + pamqp.specification.Channel.OpenOk.name: self.open_ok, + pamqp.specification.Channel.FlowOk.name: self.flow_ok, + pamqp.specification.Channel.CloseOk.name: self.close_ok, + pamqp.specification.Channel.Close.name: self.server_channel_close, + + pamqp.specification.Exchange.DeclareOk.name: self.exchange_declare_ok, + pamqp.specification.Exchange.BindOk.name: self.exchange_bind_ok, + pamqp.specification.Exchange.UnbindOk.name: self.exchange_unbind_ok, + pamqp.specification.Exchange.DeleteOk.name: self.exchange_delete_ok, + + pamqp.specification.Queue.DeclareOk.name: self.queue_declare_ok, + pamqp.specification.Queue.DeleteOk.name: self.queue_delete_ok, + pamqp.specification.Queue.BindOk.name: self.queue_bind_ok, + pamqp.specification.Queue.UnbindOk.name: self.queue_unbind_ok, + pamqp.specification.Queue.PurgeOk.name: self.queue_purge_ok, + + pamqp.specification.Basic.QosOk.name: self.basic_qos_ok, + pamqp.specification.Basic.ConsumeOk.name: self.basic_consume_ok, + pamqp.specification.Basic.CancelOk.name: self.basic_cancel_ok, + pamqp.specification.Basic.GetOk.name: self.basic_get_ok, + pamqp.specification.Basic.GetEmpty.name: self.basic_get_empty, + pamqp.specification.Basic.Deliver.name: self.basic_deliver, + pamqp.specification.Basic.Cancel.name: self.server_basic_cancel, + pamqp.specification.Basic.Ack.name: self.basic_server_ack, + pamqp.specification.Basic.Nack.name: self.basic_server_nack, + pamqp.specification.Basic.RecoverOk.name: self.basic_recover_ok, + pamqp.specification.Basic.Return.name: self.basic_return, + + pamqp.specification.Confirm.SelectOk.name: self.confirm_select_ok, } - if (frame.class_id, frame.method_id) not in methods: - raise NotImplementedError( - "Frame (%s, %s) is not implemented" % (frame.class_id, frame.method_id) - ) + if frame.name not in methods: + raise NotImplementedError("Frame %s is not implemented" % frame.name) + await methods[(frame.class_id, frame.method_id)](frame) async def _write_frame(self, frame, request, check_open=True, drain=True): @@ -237,15 +216,16 @@ async def _write_frame_awaiting_response( await self._write_frame(frame, request, check_open=check_open, drain=drain) return None - f = self._set_waiter(waiter_id) - try: - await self._write_frame(frame, request, check_open=check_open, drain=drain) - except BaseException as exc: - self._get_waiter(waiter_id) - await f.cancel() - raise - res = await f() - return res + async with self._write_lock: + f = self._set_waiter(waiter_id) + try: + await self._write_frame(frame, request, check_open=check_open, drain=drain) + except BaseException as exc: + self._get_waiter(waiter_id) + await f.cancel() + raise + res = await f() + return res # # Channel class implementation @@ -253,13 +233,10 @@ async def _write_frame_awaiting_response( async def open(self): """Open the channel on the server.""" - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_CHANNEL, amqp_constants.CHANNEL_OPEN) - request = amqp_frame.AmqpEncoder() - request.write_shortstr('') + request = pamqp.specification.Channel.Open() return ( await self._write_frame_awaiting_response( - 'open', frame, request, no_wait=False, check_open=False + 'open', self.channel_id, request, no_wait=False, check_open=False ) ) @@ -276,19 +253,8 @@ async def close(self, reply_code=0, reply_text="Normal Shutdown"): await self.close_event.set() if self._q is not None: self._q.put_nowait(None) - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_CHANNEL, amqp_constants.CHANNEL_CLOSE) - request = amqp_frame.AmqpEncoder() - request.write_short(reply_code) - request.write_shortstr(reply_text) - request.write_short(0) - request.write_short(0) - async with self._write_lock: - return ( - await self._write_frame_awaiting_response( - 'close', frame, request, no_wait=False, check_open=False - ) - ) + request = pamqp.specification.Channel.Close(reply_code, reply_text, class_id=0, method_id=0) + return await self._write_frame_awaiting_response('close', self.channel_id, request, no_wait=False, check_open=False) async def close_ok(self, frame): try: @@ -301,11 +267,9 @@ async def close_ok(self, frame): self.protocol.release_channel_id(self.channel_id) async def _send_channel_close_ok(self): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_CHANNEL, amqp_constants.CHANNEL_CLOSE_OK) - request = amqp_frame.AmqpEncoder() + request = pamqp.specification.Channel.CloseOk() # intentionally not locked - await self._write_frame(frame, request) + await self._write_frame(channel_id, request) async def server_channel_close(self, frame): try: @@ -313,31 +277,21 @@ async def server_channel_close(self, frame): except exceptions.ChannelClosed: pass results = { - 'reply_code': frame.payload_decoder.read_short(), - 'reply_text': frame.payload_decoder.read_shortstr(), - 'class_id': frame.payload_decoder.read_short(), - 'method_id': frame.payload_decoder.read_short(), + 'reply_code': frame.reply_code, + 'reply_text': frame.reply_text, + 'class_id': frame.class_id, + 'method_id': frame.method_id, } await self.connection_closed(results['reply_code'], results['reply_text']) async def flow(self, active): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_CHANNEL, amqp_constants.CHANNEL_FLOW) - request = amqp_frame.AmqpEncoder() - request.write_bits(active) - async with self._write_lock: - return ( - await self._write_frame_awaiting_response( - 'flow', frame, request, no_wait=False, check_open=False - ) - ) + request = pamqp.specification.Channel.Flow(active) + return await self._write_frame_awaiting_response('flow', self.channel_id, request, no_wait=False, check_open=False) async def flow_ok(self, frame): - decoder = amqp_frame.AmqpDecoder(frame.payload) - active = bool(decoder.read_octet()) self.close_event = anyio.create_event() fut = self._get_waiter('flow') - await fut.set_result({'active': active}) + await fut.set_result({'active': frame.active}) logger.debug("Flow ok") @@ -355,23 +309,17 @@ async def exchange_declare( no_wait=False, arguments=None ): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_EXCHANGE, amqp_constants.EXCHANGE_DECLARE) - request = amqp_frame.AmqpEncoder() - # short reserved-1 - request.write_short(0) - request.write_shortstr(exchange_name) - request.write_shortstr(type_name) - - internal = False # internal: deprecated - request.write_bits(passive, durable, auto_delete, internal, no_wait) - request.write_table(arguments) + request = pamqp.specification.Exchange.Declare( + exchange=exchange_name, + exchange_type=type_name, + passive=passive, + durable=durable, + auto_delete=auto_delete, + nowait=no_wait, + arguments=arguments + ) - async with self._write_lock: - return ( - await - self._write_frame_awaiting_response('exchange_declare', frame, request, no_wait) - ) + return await self._write_frame_awaiting_response('exchange_declare', self.channel_id, request, no_wait) async def exchange_declare_ok(self, frame): future = self._get_waiter('exchange_declare') @@ -380,19 +328,9 @@ async def exchange_declare_ok(self, frame): return future async def exchange_delete(self, exchange_name, if_unused=False, no_wait=False): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_EXCHANGE, amqp_constants.EXCHANGE_DELETE) - request = amqp_frame.AmqpEncoder() - # short reserved-1 - request.write_short(0) - request.write_shortstr(exchange_name) - request.write_bits(if_unused, no_wait) + request = pamqp.specification.Exchange.Delete(exchange=exchange_name, if_unused=if_unused, nowait=no_wait) - async with self._write_lock: - return ( - await - self._write_frame_awaiting_response('exchange_delete', frame, request, no_wait) - ) + return await self._write_frame_awaiting_response('exchange_delete', frame, request, no_wait) async def exchange_delete_ok(self, frame): future = self._get_waiter('exchange_delete') @@ -404,22 +342,14 @@ async def exchange_bind( ): if arguments is None: arguments = {} - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_EXCHANGE, amqp_constants.EXCHANGE_BIND) - - request = amqp_frame.AmqpEncoder() - request.write_short(0) # reserved - request.write_shortstr(exchange_destination) - request.write_shortstr(exchange_source) - request.write_shortstr(routing_key) - - request.write_bits(no_wait) - request.write_table(arguments) - async with self._write_lock: - return ( - await - self._write_frame_awaiting_response('exchange_bind', frame, request, no_wait) - ) + request = pamqp.specification.Exchange.Bind( + destination=exchange_destination, + source=exchange_source, + routing_key=routing_key, + nowait=no_wait, + arguments=arguments + ) + return await self._write_frame_awaiting_response('exchange_bind', self.channel_id, request, no_wait) async def exchange_bind_ok(self, frame): future = self._get_waiter('exchange_bind') @@ -431,22 +361,15 @@ async def exchange_unbind( ): if arguments is None: arguments = {} - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.EXCHANGE_UNBIND, amqp_constants.EXCHANGE_UNBIND) - request = amqp_frame.AmqpEncoder() - request.write_short(0) # reserved - request.write_shortstr(exchange_destination) - request.write_shortstr(exchange_source) - request.write_shortstr(routing_key) - - request.write_bits(no_wait) - request.write_table(arguments) - async with self._write_lock: - return ( - await - self._write_frame_awaiting_response('exchange_unbind', frame, request, no_wait) - ) + request = pamqp.specification.Exchange.Unbind( + destination=exchange_destination, + source=exchange_source, + routing_key=routing_key, + nowait=no_wait, + arguments=arguments, + ) + return await self._write_frame_awaiting_response('exchange_unbind', self.channel_id, request, no_wait) async def exchange_unbind_ok(self, frame): future = self._get_waiter('exchange_unbind') @@ -494,24 +417,22 @@ async def queue_declare( if not queue_name: queue_name = '' - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_QUEUE, amqp_constants.QUEUE_DECLARE) - request = amqp_frame.AmqpEncoder() - request.write_short(0) # reserved - request.write_shortstr(queue_name) - request.write_bits(passive, durable, exclusive, auto_delete, no_wait) - request.write_table(arguments) - async with self._write_lock: - return ( - await - self._write_frame_awaiting_response('queue_declare', frame, request, no_wait) - ) + request = pamqp.specification.Queue.Declare( + queue=queue_name, + passive=passive, + durable=durable, + exclusive=exclusive, + auto_delete=auto_delete, + nowait=no_wait, + arguments=arguments + ) + return await self._write_frame_awaiting_response('queue_declare', self.channel_id, request, no_wait) async def queue_declare_ok(self, frame): results = { - 'queue': frame.payload_decoder.read_shortstr(), - 'message_count': frame.payload_decoder.read_long(), - 'consumer_count': frame.payload_decoder.read_long(), + 'queue': frame.queue, + 'message_count': frame.message_count, + 'consumer_count': frame.consumer_count, } future = self._get_waiter('queue_declare') await future.set_result(results) @@ -531,17 +452,13 @@ async def queue_delete(self, queue_name, if_unused=False, if_empty=False, no_wai no_wait: bool, if set, the server will not respond to the method """ - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_QUEUE, amqp_constants.QUEUE_DELETE) - - request = amqp_frame.AmqpEncoder() - request.write_short(0) # reserved - request.write_shortstr(queue_name) - request.write_bits(if_unused, if_empty, no_wait) - async with self._write_lock: - return ( - await self._write_frame_awaiting_response('queue_delete', frame, request, no_wait) - ) + request = pamqp.specification.Queue.Delete( + queue=queue_name, + if_unused=if_unused, + if_empty=if_empty, + nowait=no_wait + ) + return await self._write_frame_awaiting_response('queue_delete', self.channel_id, request, no_wait) async def queue_delete_ok(self, frame): future = self._get_waiter('queue_delete') @@ -554,21 +471,14 @@ async def queue_bind( """Bind a queue to an exchange.""" if arguments is None: arguments = {} - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_QUEUE, amqp_constants.QUEUE_BIND) - - request = amqp_frame.AmqpEncoder() - # short reserved-1 - request.write_short(0) - request.write_shortstr(queue_name) - request.write_shortstr(exchange_name) - request.write_shortstr(routing_key) - request.write_octet(int(no_wait)) - request.write_table(arguments) - async with self._write_lock: - return ( - await self._write_frame_awaiting_response('queue_bind', frame, request, no_wait) - ) + request = pamqp.specification.Queue.Bind( + queue=queue_name, + exchange=exchange_name, + routing_key=routing_key, + nowait=no_wait, + arguments=arguments + ) + return await self._write_frame_awaiting_response('queue_bind', self.channel_id, request, no_wait) async def queue_bind_ok(self, frame): future = self._get_waiter('queue_bind') @@ -578,21 +488,13 @@ async def queue_bind_ok(self, frame): async def queue_unbind(self, queue_name, exchange_name, routing_key, arguments=None): if arguments is None: arguments = {} - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_QUEUE, amqp_constants.QUEUE_UNBIND) - - request = amqp_frame.AmqpEncoder() - # short reserved-1 - request.write_short(0) - request.write_shortstr(queue_name) - request.write_shortstr(exchange_name) - request.write_shortstr(routing_key) - request.write_table(arguments) - async with self._write_lock: - return ( - await - self._write_frame_awaiting_response('queue_unbind', frame, request, no_wait=False) - ) + request = pamqp.specification.Queue.Unbind( + queue=queue_name, + exchange=exchange_name, + routing_key=routing_key, + arguments=arguments + ) + return await self._write_frame_awaiting_response('queue_unbind', self.channel_id, request, no_wait=False) async def queue_unbind_ok(self, frame): future = self._get_waiter('queue_unbind') @@ -600,26 +502,14 @@ async def queue_unbind_ok(self, frame): logger.debug("Queue unbound") async def queue_purge(self, queue_name, no_wait=False): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_QUEUE, amqp_constants.QUEUE_PURGE) - - request = amqp_frame.AmqpEncoder() - # short reserved-1 - request.write_short(0) - request.write_shortstr(queue_name) - request.write_octet(int(no_wait)) - async with self._write_lock: - return ( - await self._write_frame_awaiting_response( - 'queue_purge', frame, request, no_wait=no_wait - ) - ) + request = pamqp.specification.Queue.Purge( + queue=queue_name, nowait=no_wait + ) + return await self._write_frame_awaiting_response('queue_purge', self.channel_id, request, no_wait=no_wait) async def queue_purge_ok(self, frame): - decoder = amqp_frame.AmqpDecoder(frame.payload) - message_count = decoder.read_long() future = self._get_waiter('queue_purge') - await future.set_result({'message_count': message_count}) + await future.set_result({'message_count': frame.message_count}) # # Basic class implementation @@ -634,42 +524,34 @@ async def basic_publish( mandatory=False, immediate=False ): - assert payload, "Payload cannot be empty" - async with self._write_lock: - method_frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - method_frame.declare_method(amqp_constants.CLASS_BASIC, amqp_constants.BASIC_PUBLISH) - method_request = amqp_frame.AmqpEncoder() - method_request.write_short(0) - method_request.write_shortstr(exchange_name) - method_request.write_shortstr(routing_key) - method_request.write_bits(mandatory, immediate) - await self._write_frame(method_frame, method_request, drain=False) - - header_frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_HEADER, self.channel_id) - header_frame.declare_class(amqp_constants.CLASS_BASIC) - header_frame.set_body_size(len(payload)) - encoder = amqp_frame.AmqpEncoder() - encoder.write_message_properties(properties) - await self._write_frame(header_frame, encoder, drain=False) + if properties is None: + properties = {} + + method_request = pamqp.specification.Basic.Publish( + exchange=exchange_name, + routing_key=routing_key, + mandatory=mandatory, + immediate=immediate + ) - # split the payload + await self._write_frame(self.channel_id, method_request, drain=False) + + header_request = pamqp.header.ContentHeader( + body_size=len(payload), + properties=pamqp.specification.Basic.Properties(**properties) + ) + await self._write_frame(self.channel_id, header_request, drain=False) frame_max = self.protocol.server_frame_max or len(payload) for chunk in (payload[0 + i:frame_max + i] for i in range(0, len(payload), frame_max)): - content_frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_BODY, self.channel_id) - content_frame.declare_class(amqp_constants.CLASS_BASIC) - encoder = amqp_frame.AmqpEncoder() - if isinstance(chunk, str): - encoder.payload.write(chunk.encode()) - else: - encoder.payload.write(chunk) - await self._write_frame(content_frame, encoder, drain=False) + content_request = pamqp.body.ContentBody(chunk) + await self._write_frame(self.channel_id, content_request, drain=False) await self.protocol._drain() - async def basic_qos(self, prefetch_size=0, prefetch_count=0, connection_global=None): + async def basic_qos(self, prefetch_size=0, prefetch_count=0, connection_global=False): """Specifies quality of service. Args: @@ -688,18 +570,10 @@ async def basic_qos(self, prefetch_size=0, prefetch_count=0, connection_global=N per-consumer channel; and global=true to mean that the QoS settings should apply per-channel. """ - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_BASIC, amqp_constants.BASIC_QOS) - request = amqp_frame.AmqpEncoder() - request.write_long(prefetch_size) - request.write_short(prefetch_count) - request.write_bits(connection_global) - - async with self._write_lock: - return ( - await - self._write_frame_awaiting_response('basic_qos', frame, request, no_wait=False) - ) + request = pamqp.specification.Basic.Qos( + prefetch_size, prefetch_count, connection_global + ) + return await self._write_frame_awaiting_response('basic_qos', self.channel_id, request, no_wait=False) async def basic_qos_ok(self, frame): future = self._get_waiter('basic_qos') @@ -708,8 +582,7 @@ async def basic_qos_ok(self, frame): async def basic_server_nack(self, frame, delivery_tag=None): if delivery_tag is None: - decoder = amqp_frame.AmqpDecoder(frame.payload) - delivery_tag = decoder.read_long_long() + delivery_tag = frame.delivery_tag fut = self._get_waiter('basic_server_ack_{}'.format(delivery_tag)) logger.debug('Received nack for delivery tag %r', delivery_tag) await fut.set_exception(exceptions.PublishFailed(delivery_tag)) @@ -728,7 +601,7 @@ def new_consumer( Usage:: - async with chan.new_consumer(callback, queue_name="my_queue") \ + async with chan.new_consumer(queue_name="my_queue") \ as listener: async for body, envelope, properties in listener: await process_message(body, envelope, properties) @@ -838,22 +711,20 @@ async def basic_consume( if arguments is None: arguments = {} - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_BASIC, amqp_constants.BASIC_CONSUME) - request = amqp_frame.AmqpEncoder() - request.write_short(0) - request.write_shortstr(queue_name) - request.write_shortstr(consumer_tag) - request.write_bits(no_local, no_ack, exclusive, no_wait) - request.write_table(arguments) + request = pamqp.specification.Basic.Consume( + queue=queue_name, + consumer_tag=consumer_tag, + no_local=no_local, + no_ack=no_ack, + exclusive=exclusive, + nowait=no_wait, + arguments=arguments + ) self.consumer_callbacks[consumer_tag] = callback self.last_consumer_tag = consumer_tag - async with self._write_lock: - return_value = await self._write_frame_awaiting_response( - 'basic_consume', frame, request, no_wait - ) + return_value = await self._write_frame_awaiting_response('basic_consume', self.channel_id, request, no_wait) if no_wait: return_value = {'consumer_tag': consumer_tag} else: @@ -861,31 +732,29 @@ async def basic_consume( return return_value async def basic_consume_ok(self, frame): - ctag = frame.payload_decoder.read_shortstr() results = { - 'consumer_tag': ctag, + 'consumer_tag': frame.consumer_tag } future = self._get_waiter('basic_consume') await future.set_result(results) self._ctag_events[ctag] = anyio.create_event() async def basic_deliver(self, frame): - response = amqp_frame.AmqpDecoder(frame.payload) - consumer_tag = response.read_shortstr() - delivery_tag = response.read_long_long() - is_redeliver = response.read_bit() - exchange_name = response.read_shortstr() - routing_key = response.read_shortstr() - content_header_frame = await self.protocol.get_frame() + consumer_tag = frame.consumer_tag + delivery_tag = frame.delivery_tag + is_redeliver = frame.redelivered + exchange_name = frame.exchange + routing_key = frame.routing_key + channel, content_header_frame = await self.protocol.get_frame() buffer = io.BytesIO() while (buffer.tell() < content_header_frame.body_size): - content_body_frame = await self.protocol.get_frame() - buffer.write(content_body_frame.payload) + _channel, content_body_frame = await self.protocol.get_frame() + buffer.write(content_body_frame.value) body = buffer.getvalue() envelope = Envelope(consumer_tag, delivery_tag, exchange_name, routing_key, is_redeliver) - properties = content_header_frame.properties + properties = amqp_properties.from_pamqp(content_header_frame.properties) callback = self.consumer_callbacks[consumer_tag] @@ -900,8 +769,8 @@ async def basic_deliver(self, frame): async def server_basic_cancel(self, frame): # https://www.rabbitmq.com/consumer-cancel.html - consumer_tag = frame.payload_decoder.read_shortstr() - _no_wait = frame.payload_decoder.read_bit() # noqa: F841 + consumer_tag = frame.consumer_tag + _no_wait = frame.nowait self.cancelled_consumers.add(consumer_tag) logger.info("consume cancelled received") callback = self.consumer_callbacks.get(consumer_tag, None) @@ -914,79 +783,60 @@ async def server_basic_cancel(self, frame): async def basic_cancel(self, consumer_tag, no_wait=False): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_BASIC, amqp_constants.BASIC_CANCEL) - request = amqp_frame.AmqpEncoder() - request.write_shortstr(consumer_tag) - request.write_bits(no_wait) - async with self._write_lock: - return ( - await self._write_frame_awaiting_response( - 'basic_cancel', frame, request, no_wait=no_wait - ) - ) + request = pamqp.specification.Basic.Cancel(consumer_tag, no_wait) + return await self._write_frame_awaiting_response('basic_cancel', self.channel_id, request, no_wait=no_wait) async def basic_cancel_ok(self, frame): results = { - 'consumer_tag': frame.payload_decoder.read_shortstr(), + 'consumer_tag': frame.consumer_tag, } future = self._get_waiter('basic_cancel') await future.set_result(results) logger.debug("Cancel ok") async def basic_return(self, frame): - response = amqp_frame.AmqpDecoder(frame.payload) - reply_code = response.read_short() - reply_text = response.read_shortstr() - exchange_name = response.read_shortstr() - routing_key = response.read_shortstr() - content_header_frame = await self.protocol.get_frame() + reply_code = frame.reply_code + reply_text = frame.reply_text + exchange_name = frame.exchange + routing_key = frame.routing_key + channel, content_header_frame = await self.protocol.get_frame() buffer = io.BytesIO() while buffer.tell() < content_header_frame.body_size: - content_body_frame = await self.protocol.get_frame() - buffer.write(content_body_frame.payload) + _channel, content_body_frame = await self.protocol.get_frame() + buffer.write(content_body_frame.value) body = buffer.getvalue() envelope = ReturnEnvelope(reply_code, reply_text, exchange_name, routing_key) - properties = content_header_frame.properties + properties = amqp_properties.from_pamqp(content_header_frame.properties) if self._q is None: - # they have set mandatory bit, but havent added a callback + # they have set mandatory bit, but aren't reading logger.warning("You don't iterate the channel for returned messages!") else: await self._q.put((body, envelope, properties)) async def basic_get(self, queue_name='', no_ack=False): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_BASIC, amqp_constants.BASIC_GET) - request = amqp_frame.AmqpEncoder() - request.write_short(0) - request.write_shortstr(queue_name) - request.write_bits(no_ack) - async with self._write_lock: - return ( - await - self._write_frame_awaiting_response('basic_get', frame, request, no_wait=False) - ) + request = pamqp.specification.Basic.Get(queue=queue_name, no_ack=no_ack) + return await self._write_frame_awaiting_response('basic_get', self.channel_id, request, no_wait=False) async def basic_get_ok(self, frame): - data = {} - decoder = amqp_frame.AmqpDecoder(frame.payload) - data['delivery_tag'] = decoder.read_long_long() - data['redelivered'] = bool(decoder.read_octet()) - data['exchange_name'] = decoder.read_shortstr() - data['routing_key'] = decoder.read_shortstr() - data['message_count'] = decoder.read_long() - content_header_frame = await self.protocol.get_frame() + data = { + 'delivery_tag': frame.delivery_tag, + 'redelivered': frame.redelivered, + 'exchange_name': frame.exchange, + 'routing_key': frame.routing_key, + 'message_count': frame.message_count, + } + _channel, content_header_frame = await self.protocol.get_frame() buffer = io.BytesIO() while (buffer.tell() < content_header_frame.body_size): - content_body_frame = await self.protocol.get_frame() + _channel, content_body_frame = await self.protocol.get_frame() buffer.write(content_body_frame.payload) data['message'] = buffer.getvalue() - data['properties'] = content_header_frame.properties + data['properties'] = amqp_properties.from_pamqp(content_header_frame.properties) future = self._get_waiter('basic_get') await future.set_result(data) @@ -995,58 +845,35 @@ async def basic_get_empty(self, frame): await future.set_exception(exceptions.EmptyQueue) async def basic_client_ack(self, delivery_tag, multiple=False): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_BASIC, amqp_constants.BASIC_ACK) - request = amqp_frame.AmqpEncoder() - request.write_long_long(delivery_tag) + request = pamqp.specification.Basic.Ack(delivery_tag, multiple) request.write_bits(multiple) async with self._write_lock: - await self._write_frame(frame, request) + await self._write_frame(self.channel_id, request) async def basic_client_nack(self, delivery_tag, multiple=False, requeue=True): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_BASIC, amqp_constants.BASIC_NACK) - request = amqp_frame.AmqpEncoder() - request.write_long_long(delivery_tag) - request.write_bits(multiple, requeue) + request = pamqp.specification.Basic.Nack(delivery_tag, multiple, requeue) async with self._write_lock: - await self._write_frame(frame, request) + await self._write_frame(self.channel_id, request) async def basic_server_ack(self, frame): - decoder = amqp_frame.AmqpDecoder(frame.payload) - delivery_tag = decoder.read_long_long() + delivery_tag = frame.delivery_tag fut = self._get_waiter('basic_server_ack_{}'.format(delivery_tag)) logger.debug('Received ack for delivery tag %s', delivery_tag) await fut.set_result(True) async def basic_reject(self, delivery_tag, requeue=False): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_BASIC, amqp_constants.BASIC_REJECT) - request = amqp_frame.AmqpEncoder() - request.write_long_long(delivery_tag) - request.write_bits(requeue) + request = pamqp.specification.Basic.Reject(delivery_tag, requeue) async with self._write_lock: - await self._write_frame(frame, request) + await self._write_frame(self.channel_id, request) async def basic_recover_async(self, requeue=True): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_BASIC, amqp_constants.BASIC_RECOVER_ASYNC) - request = amqp_frame.AmqpEncoder() - request.write_bits(requeue) + request = pamqp.specification.Basic.RecoverAsync(requeue) async with self._write_lock: - await self._write_frame(frame, request) + await self._write_frame(self.channel_id, request) async def basic_recover(self, requeue=True): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - frame.declare_method(amqp_constants.CLASS_BASIC, amqp_constants.BASIC_RECOVER) - request = amqp_frame.AmqpEncoder() - request.write_bits(requeue) - async with self._write_lock: - return ( - await self._write_frame_awaiting_response( - 'basic_recover', frame, request, no_wait=False - ) - ) + request = pamqp.specification.Basic.Recover(requeue) + return await self._write_frame_awaiting_response('basic_recover', self.channel_id, request, no_wait=False) async def basic_recover_ok(self, frame): future = self._get_waiter('basic_recover') @@ -1070,44 +897,34 @@ async def publish( mandatory=False, immediate=False ): - assert payload, "Payload cannot be empty" - async with self._write_lock: if self.publisher_confirms: delivery_tag = next(self.delivery_tag_iter) fut = self._set_waiter('basic_server_ack_{}'.format(delivery_tag)) - method_frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id) - method_frame.declare_method(amqp_constants.CLASS_BASIC, amqp_constants.BASIC_PUBLISH) - method_request = amqp_frame.AmqpEncoder() - method_request.write_short(0) - method_request.write_shortstr(exchange_name) - method_request.write_shortstr(routing_key) - method_request.write_bits(mandatory, immediate) - await self._write_frame(method_frame, method_request, drain=False) - - header_frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_HEADER, self.channel_id) - header_frame.declare_class(amqp_constants.CLASS_BASIC) - header_frame.set_body_size(len(payload)) - encoder = amqp_frame.AmqpEncoder() - encoder.write_message_properties(properties) - await self._write_frame(header_frame, encoder, drain=False) + method_request = pamqp.specification.Basic.Publish( + exchange=exchange_name, + routing_key=routing_key, + mandatory=mandatory, + immediate=immediate + ) + await self._write_frame(self.channel_id, method_request, drain=False) + + properties = pamqp.specification.Basic.Properties(**properties) + header_request = pamqp.header.ContentHeader( + body_size=len(payload), properties=properties + ) + + await self._write_frame(self.channel_id, encoder, drain=False) # split the payload frame_max = self.protocol.server_frame_max or len(payload) for chunk in (payload[0 + i:frame_max + i] for i in range(0, len(payload), frame_max)): + content_request = pamqp.body.ContentBody(chunk) + await self._write_frame(self.channel_id, encoder, drain=False) - content_frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_BODY, self.channel_id) - content_frame.declare_class(amqp_constants.CLASS_BASIC) - encoder = amqp_frame.AmqpEncoder() - if isinstance(chunk, str): - encoder.payload.write(chunk.encode()) - else: - encoder.payload.write(chunk) - await self._write_frame(content_frame, encoder, drain=False) - - await self.protocol._drain() + await self.protocol._drain() if self.publisher_confirms: await fut() @@ -1120,11 +937,7 @@ async def confirm_select(self, *, no_wait=False): request = amqp_frame.AmqpEncoder() request.write_shortstr('') - async with self._write_lock: - return ( - await - self._write_frame_awaiting_response('confirm_select', frame, request, no_wait) - ) + return await self._write_frame_awaiting_response('confirm_select', frame, request, no_wait) async def confirm_select_ok(self, frame): self.publisher_confirms = True diff --git a/asyncamqp/constants.py b/asyncamqp/constants.py index f7945b7..5866bd6 100644 --- a/asyncamqp/constants.py +++ b/asyncamqp/constants.py @@ -94,7 +94,7 @@ MESSAGE_PROPERTIES = ( 'content_type', 'content_encoding', 'headers', 'delivery_mode', 'priority', 'correlation_id', - 'reply_to', 'expiration', 'message_id', 'timestamp', 'type', 'user_id', 'app_id', 'cluster_id' + 'reply_to', 'expiration', 'message_id', 'timestamp', 'message_type', 'user_id', 'app_id', 'cluster_id', ) FLAG_CONTENT_TYPE = (1 << 15) diff --git a/asyncamqp/frame.py b/asyncamqp/frame.py index 1157dae..22dd9b3 100644 --- a/asyncamqp/frame.py +++ b/asyncamqp/frame.py @@ -43,523 +43,63 @@ import socket import os import datetime +from anyio.exceptions import ClosedResourceError, IncompleteRead from itertools import count from decimal import Decimal +import pamqp.encode +import pamqp.specification +import pamqp.frame + from . import exceptions from . import constants as amqp_constants from .properties import Properties DUMP_FRAMES = False +async def write(writer, channel, encoder): + """Writes the built frame from the encoder -class AmqpEncoder: - def __init__(self): - self.payload = io.BytesIO() - - def write_table(self, data_dict): - - self.write_long(0) # the table length (set later) - if data_dict: - start = self.payload.tell() - for key, value in data_dict.items(): - self.write_shortstr(key) - self.write_value(value) - table_length = self.payload.tell() - start - self.payload.seek(start - 4) # move before the long - self.write_long(table_length) # and set the table length - self.payload.seek(0, os.SEEK_END) # return at the end - - def write_array(self, value): - array_data = AmqpEncoder() - for item in value: - array_data.write_value(item) - array_data = array_data.payload.getvalue() - self.write_long(len(array_data)) - self.payload.write(array_data) - - def write_value(self, value): - if isinstance(value, (bytes, str)): - self.payload.write(b'S') - self.write_longstr(value) - elif isinstance(value, bool): - self.payload.write(b't') - self.write_bool(value) - elif isinstance(value, dict): - self.payload.write(b'F') - self.write_table(value) - elif isinstance(value, int): - self.payload.write(b'I') - self.write_long(value) - elif isinstance(value, float): - self.payload.write(b'd') - self.write_float(value) - elif isinstance(value, (list, tuple)): - self.payload.write(b'A') - self.write_array(value) - elif isinstance(value, Decimal): - self.payload.write(b'D') - self.write_decimal(value) - elif isinstance(value, datetime.datetime): - self.payload.write(b'T') - self.write_timestamp(value) - elif value is None: - self.payload.write(b'V') - else: - raise Exception("type({}) unsupported".format(type(value))) - - def write_bits(self, *args): - """Write consecutive bools to one byte""" - assert len(args) <= 8, "write_bits can only write 8 bits into one octet, sadly" - byte_value = 0 - - for arg_index, bit in enumerate(args): - if bit: - byte_value |= (1 << arg_index) - - self.write_octet(byte_value) - - def write_bool(self, value): - self.payload.write(struct.pack('?', value)) - - def write_octet(self, octet): - self.payload.write(struct.pack('!B', octet)) - - def write_short(self, short): - self.payload.write(struct.pack('!H', short)) - - def write_long(self, integer): - self.payload.write(struct.pack('!I', integer)) - - def write_long_long(self, longlong): - self.payload.write(struct.pack('!Q', longlong)) - - def write_float(self, value): - self.payload.write(struct.pack('>d', value)) - - def write_decimal(self, value): - sign, digits, exponent = value.as_tuple() - v = 0 - for d in digits: - v = (v * 10) + d - if sign: - v = -v - self.write_octet(-exponent) - self.payload.write(struct.pack('>i', v)) - - def write_timestamp(self, value): - """Write out a Python datetime.datetime object as a 64-bit integer - representing seconds since the Unix epoch. - """ - self.payload.write( - struct.pack('>Q', int(value.replace(tzinfo=datetime.timezone.utc).timestamp())) - ) - - def _write_string(self, string): - if isinstance(string, str): - self.payload.write(string.encode()) - elif isinstance(string, bytes): - self.payload.write(string) - - def write_longstr(self, string): - self.write_long(len(string)) - self._write_string(string) - - def write_shortstr(self, string): - self.write_octet(len(string)) - self._write_string(string) - - def write_message_properties(self, properties): - - properties_flag_value = 0 - if properties is None: - self.write_short(0) - return - - diff = set(properties.keys()) - set(amqp_constants.MESSAGE_PROPERTIES) - if diff: - raise ValueError( - "%s are not properties, valid properties are %s" % - (diff, amqp_constants.MESSAGE_PROPERTIES) - ) - - start = self.payload.tell() # record the position - self.write_short(properties_flag_value) # set the flag later - - content_type = properties.get('content_type') - if content_type: - properties_flag_value |= amqp_constants.FLAG_CONTENT_TYPE - self.write_shortstr(content_type) - content_encoding = properties.get('content_encoding') - if content_encoding: - properties_flag_value |= amqp_constants.FLAG_CONTENT_ENCODING - self.write_shortstr(content_encoding) - headers = properties.get('headers') - if headers is not None: - properties_flag_value |= amqp_constants.FLAG_HEADERS - self.write_table(headers) - delivery_mode = properties.get('delivery_mode') - if delivery_mode is not None: - properties_flag_value |= amqp_constants.FLAG_DELIVERY_MODE - self.write_octet(delivery_mode) - priority = properties.get('priority') - if priority is not None: - properties_flag_value |= amqp_constants.FLAG_PRIORITY - self.write_octet(priority) - correlation_id = properties.get('correlation_id') - if correlation_id: - properties_flag_value |= amqp_constants.FLAG_CORRELATION_ID - self.write_shortstr(correlation_id) - reply_to = properties.get('reply_to') - if reply_to: - properties_flag_value |= amqp_constants.FLAG_REPLY_TO - self.write_shortstr(reply_to) - expiration = properties.get('expiration') - if expiration: - properties_flag_value |= amqp_constants.FLAG_EXPIRATION - self.write_shortstr(expiration) - message_id = properties.get('message_id') - if message_id: - properties_flag_value |= amqp_constants.FLAG_MESSAGE_ID - self.write_shortstr(message_id) - timestamp = properties.get('timestamp') - if timestamp is not None: - properties_flag_value |= amqp_constants.FLAG_TIMESTAMP - self.write_long_long(timestamp) - type_ = properties.get('type') - if type_: - properties_flag_value |= amqp_constants.FLAG_TYPE - self.write_shortstr(type_) - user_id = properties.get('user_id') - if user_id: - properties_flag_value |= amqp_constants.FLAG_USER_ID - self.write_shortstr(user_id) - app_id = properties.get('app_id') - if app_id: - properties_flag_value |= amqp_constants.FLAG_APP_ID - self.write_shortstr(app_id) - cluster_id = properties.get('cluster_id') - if cluster_id: - properties_flag_value |= amqp_constants.FLAG_CLUSTER_ID - self.write_shortstr(cluster_id) - - self.payload.seek(start) # move before the flag - self.write_short(properties_flag_value) # set the flag - self.payload.seek(0, os.SEEK_END) - - -class AmqpDecoder: - def __init__(self, reader): - self.reader = reader - - def read_bit(self): - return bool(self.read_octet()) - - def read_octet(self): - data = self.reader.read(1) - return ord(data) - - def read_signed_octet(self): - data = self.reader.read(1) - return struct.unpack('!b', data)[0] - - def read_short(self): - data = self.reader.read(2) - return struct.unpack('!H', data)[0] - - def read_signed_short(self): - data = self.reader.read(2) - return struct.unpack('!h', data)[0] - - def read_long(self): - data = self.reader.read(4) - return struct.unpack('!I', data)[0] - - def read_signed_long(self): - data = self.reader.read(4) - return struct.unpack('!i', data)[0] - - def read_long_long(self): - data = self.reader.read(8) - return struct.unpack('!Q', data)[0] - - def read_signed_long_long(self): - data = self.reader.read(8) - return struct.unpack('!q', data)[0] - - def read_float(self): - # XXX: This used to read & unpack '!d', which is a double, - # not a shorter float - data = self.reader.read(4) - return struct.unpack('!f', data)[0] - - def read_double(self): - data = self.reader.read(8) - return struct.unpack('!d', data)[0] - - def read_decimal(self): - decimals = self.read_octet() - value = self.read_signed_long() - return Decimal(value) * (Decimal(10)**-decimals) - - def read_shortstr(self): - data = self.reader.read(1) - string_len = struct.unpack('!B', data)[0] - data = self.reader.read(string_len) - return data.decode() - - def read_longstr(self): - string_len = self.read_long() - data = self.reader.read(string_len) - return data.decode() - - def read_timestamp(self): - return datetime.datetime.fromtimestamp(self.read_long_long(), datetime.timezone.utc) - - def read_table(self): - """Reads an AMQP table""" - table_len = self.read_long() - table_data = AmqpDecoder(io.BytesIO(self.reader.read(table_len))) - table = {} - while table_data.reader.tell() < table_len: - var_name = table_data.read_shortstr() - var_value = self.read_table_subitem(table_data) - table[var_name] = var_value - return table - - _table_subitem_reader_map = { - 't': 'read_bit', - 'b': 'read_octet', - 'B': 'read_signed_octet', - 'U': 'read_signed_short', - 'u': 'read_short', - 'I': 'read_signed_long', - 'i': 'read_long', - 'L': 'read_unsigned_long_long', - 'l': 'read_long_long', - 'f': 'read_float', - 'd': 'read_float', - 'D': 'read_decimal', - 's': 'read_shortstr', - 'S': 'read_longstr', - 'A': 'read_field_array', - 'T': 'read_timestamp', - 'F': 'read_table', - } - - def read_table_subitem(self, table_data): - """Read `table_data` bytes, guess the type of the value, and cast it. - - table_data: a pair of b'' - """ - value_type = chr(table_data.read_octet()) - if value_type == 'V': - return None - else: - reader_name = self._table_subitem_reader_map.get(value_type) - if not reader_name: - raise ValueError('Unknown value_type {}'.format(value_type)) - return getattr(table_data, reader_name)() - - def read_field_array(self): - array_len = self.read_long() - array_data = AmqpDecoder(io.BytesIO(self.reader.read(array_len))) - field_array = [] - while array_data.reader.tell() < array_len: - item = self.read_table_subitem(array_data) - field_array.append(item) - return field_array - - -class AmqpRequest: - def __init__(self, frame_type, channel): - self.frame_type = frame_type - self.channel = channel - self.class_id = None - self.weight = None - self.method_id = None - self.next_body_size = None - - def declare_class(self, class_id, weight=0): - self.class_id = class_id - self.weight = 0 - - def set_body_size(self, size): - self.next_body_size = size - - def declare_method(self, class_id, method_id): - self.class_id = class_id - self.method_id = method_id - - def get_frame(self, encoder): - payload = encoder.payload - content_header = '' - transmission = io.BytesIO() - if self.frame_type == amqp_constants.TYPE_METHOD: - content_header = struct.pack('!HH', self.class_id, self.method_id) - elif self.frame_type == amqp_constants.TYPE_HEADER: - content_header = struct.pack('!HHQ', self.class_id, self.weight, self.next_body_size) - elif self.frame_type == amqp_constants.TYPE_BODY: - # no specific headers - pass - elif self.frame_type == amqp_constants.TYPE_HEARTBEAT: - # no specific headers - pass - else: - raise Exception("frame_type {} not handled".format(self.frame_type)) - - header = struct.pack( - '!BHI', self.frame_type, self.channel, - payload.tell() + len(content_header) - ) - transmission.write(header) - if content_header: - transmission.write(content_header) - transmission.write(payload.getvalue()) - transmission.write(amqp_constants.FRAME_END) - return transmission.getvalue() - - -class AmqpResponse: - """Read a response from the AMQP server + writer: anyio Stream + channel: amqp Channel identifier + encoder: frame encoder from pamqp which can be marshalled """ + await writer.send_all(pamqp.frame.marshal(encoder, channel)) - def __init__(self, reader): - self.reader = reader - self.frame_type = None - self.channel = 0 # default channel in AMQP - self.payload_size = None - self.frame_end = None - self.frame_payload = None - self.payload = None - self.frame_class = None - self.frame_method = None - self.class_id = None - self.method_id = None - self.weight = None - self.body_size = None - self.property_flags = None - self.properties = None - self.arguments = {} - self.frame_length = 0 - self.payload_decoder = None - self.header_decoder = None +async def read(reader): + """Read a new frame from the wire - async def _readexactly(self, length): - data = b"" - while len(data) < length: - d = await self.reader.receive_some(length - len(data)) - if len(d) == 0: - raise EOFError - data += d - return data + reader: anyio Stream - async def read_frame(self): - """Decode the frame""" - if not self.reader: - raise exceptions.AmqpClosedConnection() - try: - data = await self._readexactly(7) - except (EOFError, socket.error) as ex: - raise exceptions.AmqpClosedConnection() from ex - - frame_header = io.BytesIO(data) - self.header_decoder = AmqpDecoder(frame_header) - self.frame_type = self.header_decoder.read_octet() - self.channel = self.header_decoder.read_short() - self.frame_length = self.header_decoder.read_long() - payload_data = await self._readexactly(self.frame_length) - - if self.frame_type == amqp_constants.TYPE_METHOD: - self.payload = io.BytesIO(payload_data) - self.payload_decoder = AmqpDecoder(self.payload) - self.class_id = self.payload_decoder.read_short() - self.method_id = self.payload_decoder.read_short() - - elif self.frame_type == amqp_constants.TYPE_HEADER: - self.payload = io.BytesIO(payload_data) - self.payload_decoder = AmqpDecoder(self.payload) - self.class_id = self.payload_decoder.read_short() - self.weight = self.payload_decoder.read_short() - self.body_size = self.payload_decoder.read_long_long() - self.property_flags = 0 - for flagword_index in count(0): - partial_flags = self.payload_decoder.read_short() - self.property_flags |= partial_flags << (flagword_index * 16) - if partial_flags & 1 == 0: - break - decoded_properties = {} - if self.property_flags & amqp_constants.FLAG_CONTENT_TYPE: - decoded_properties['content_type'] = self.payload_decoder.read_shortstr() - if self.property_flags & amqp_constants.FLAG_CONTENT_ENCODING: - decoded_properties['content_encoding'] = self.payload_decoder.read_shortstr() - if self.property_flags & amqp_constants.FLAG_HEADERS: - decoded_properties['headers'] = self.payload_decoder.read_table() - if self.property_flags & amqp_constants.FLAG_DELIVERY_MODE: - decoded_properties['delivery_mode'] = self.payload_decoder.read_octet() - if self.property_flags & amqp_constants.FLAG_PRIORITY: - decoded_properties['priority'] = self.payload_decoder.read_octet() - if self.property_flags & amqp_constants.FLAG_CORRELATION_ID: - decoded_properties['correlation_id'] = self.payload_decoder.read_shortstr() - if self.property_flags & amqp_constants.FLAG_REPLY_TO: - decoded_properties['reply_to'] = self.payload_decoder.read_shortstr() - if self.property_flags & amqp_constants.FLAG_EXPIRATION: - decoded_properties['expiration'] = self.payload_decoder.read_shortstr() - if self.property_flags & amqp_constants.FLAG_MESSAGE_ID: - decoded_properties['message_id'] = self.payload_decoder.read_shortstr() - if self.property_flags & amqp_constants.FLAG_TIMESTAMP: - decoded_properties['timestamp'] = self.payload_decoder.read_long_long() - if self.property_flags & amqp_constants.FLAG_TYPE: - decoded_properties['type'] = self.payload_decoder.read_shortstr() - if self.property_flags & amqp_constants.FLAG_USER_ID: - decoded_properties['user_id'] = self.payload_decoder.read_shortstr() - if self.property_flags & amqp_constants.FLAG_APP_ID: - decoded_properties['app_id'] = self.payload_decoder.read_shortstr() - if self.property_flags & amqp_constants.FLAG_CLUSTER_ID: - decoded_properties['cluster_id'] = self.payload_decoder.read_shortstr() - self.properties = Properties(**decoded_properties) + Returns (channel, frame) a tuple containing both channel and the pamqp frame, + the object describing the frame + """ + if not reader: + raise exceptions.AmqpClosedConnection() + try: + data = await reader.receive_exactly(7) + except (ClosedResourceError, IncompleteRead): + raise exceptions.AmqpClosedConnection() from ex - elif self.frame_type == amqp_constants.TYPE_BODY: - self.payload = payload_data + frame_type, channel, frame_length = pamqp.frame.frame_parts(data) - elif self.frame_type == amqp_constants.TYPE_HEARTBEAT: - pass + payload_data = await reader.receive_exactly(frame_length) + frame = None - else: - raise ValueError("Message type {:x} not known".format(self.frame_type)) - self.frame_end = await self._readexactly(1) - assert self.frame_end == amqp_constants.FRAME_END + if frame_type == amqp_constants.TYPE_METHOD: + frame = pamqp.frame._unmarshal_method_frame(payload_data) - def __str__(self): - frame_data = { - 'type': self.frame_type, - 'channel': self.channel, - 'size': self.payload_size, - 'frame_end': self.frame_end, - 'payload': self.frame_payload, - } - output = """ -0 1 3 7 size+7 size+8 -+--------+-----------+------------+ +---------------+ +--------------+ -|{type!r:^8}|{channel!r:^11}|{size!r:^12}| |{payload!r:^15}| |{frame_end!r:^14}| -+--------+-----------+------------+ +---------------+ +--------------+ - type channel size payload frame-end -""".format(**frame_data) # noqa: E501 + elif frame_type == amqp_constants.TYPE_HEADER: + frame = pamqp.frame._unmarshal_header_frame(payload_data) - if self.frame_type == amqp_constants.TYPE_METHOD: - method_data = { - 'class_id': self.class_id, - 'method_id': self.method_id, - } - type_output = """ -0 2 4 -+----------+-----------+-------------- - - -|{class_id:^10}|{method_id:^11}| arguments... -+----------+-----------+-------------- - - - class-id method-id ...""".format(**method_data) + elif frame_type == amqp_constants.TYPE_BODY: + frame = pamqp.frame._unmarshal_body_frame(payload_data) - output += os.linesep + type_output + elif frame_type == amqp_constants.TYPE_HEARTBEAT: + frame = pamqp.heartbeat.Heartbeat() - return output + frame_end = await reader.receive_exactly(1) + assert frame_end == amqp_constants.FRAME_END + return channel, frame diff --git a/asyncamqp/properties.py b/asyncamqp/properties.py index 8ca6f47..54b906f 100644 --- a/asyncamqp/properties.py +++ b/asyncamqp/properties.py @@ -1,3 +1,4 @@ +# pylint: disable=redefined-builtin from .constants import MESSAGE_PROPERTIES @@ -17,11 +18,11 @@ def __init__( expiration=None, message_id=None, timestamp=None, - type=None, + message_type=None, user_id=None, app_id=None, cluster_id=None - ): # pylint: disable=redefined-builtin + ): self.content_type = content_type self.content_encoding = content_encoding self.headers = headers @@ -32,7 +33,26 @@ def __init__( self.expiration = expiration self.message_id = message_id self.timestamp = timestamp - self.type = type + self.message_type = message_type self.user_id = user_id self.app_id = app_id self.cluster_id = cluster_id + + +def from_pamqp(instance): + props = Properties() + props.content_type = instance.content_type + props.content_encoding = instance.content_encoding + props.headers = instance.headers + props.delivery_mode = instance.delivery_mode + props.priority = instance.priority + props.correlation_id = instance.correlation_id + props.reply_to = instance.reply_to + props.expiration = instance.expiration + props.message_id = instance.message_id + props.timestamp = instance.timestamp + props.message_type = instance.message_type + props.user_id = instance.user_id + props.app_id = instance.app_id + props.cluster_id = instance.cluster_id + return props diff --git a/asyncamqp/protocol.py b/asyncamqp/protocol.py index 716d082..7423720 100644 --- a/asyncamqp/protocol.py +++ b/asyncamqp/protocol.py @@ -10,12 +10,12 @@ import ssl from async_generator import asynccontextmanager from async_generator import async_generator,yield_ +import pamqp from . import channel as amqp_channel from . import constants as amqp_constants from . import frame as amqp_frame from . import exceptions -from . import _version logger = logging.getLogger(__name__) @@ -202,10 +202,10 @@ async def _drain(self): # # version of Python where this bugs exists is supported anymore. # await self._stream_writer.drain() - async def _write_frame(self, frame, encoder, drain=True): + async def _write_frame(self, channel_id, request, drain=True): # Doesn't actually write frame, pushes it for _writer_loop task to # pick it up. - await self._send_queue.put((frame, encoder)) + await self._send_queue.put((channel_id, request)) async def _writer_loop(self, done): async with anyio.open_cancel_scope(shield=True) as scope: @@ -218,14 +218,13 @@ async def _writer_loop(self, done): timeout = inf async with anyio.move_on_after(timeout) as timeout_scope: - frame, encoder = await self._send_queue.get() + channel_id, request = await self._send_queue.get() if timeout_scope.cancel_called: await self.send_heartbeat() continue - f = frame.get_frame(encoder) try: - await self._stream.send_all(f) + await amqp_frame.write(self._stream, channel_id, request) except (anyio.exceptions.ClosedResourceError, BrokenPipeError): # raise exceptions.AmqpClosedConnection(self) from None # the reader will raise the error also @@ -248,18 +247,14 @@ async def close(self, no_wait=False): await self._close_channels() # If the closing handshake is in progress, let it complete. - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, 0) - frame.declare_method( - amqp_constants.CLASS_CONNECTION, amqp_constants.CONNECTION_CLOSE + request = pamqp.specification.Connection.Close( + reply_code=0, + reply_text='', + class_id=0, + method_id=0 ) - encoder = amqp_frame.AmqpEncoder() - # we request a clean connection close - encoder.write_short(0) - encoder.write_shortstr('') - encoder.write_short(0) - encoder.write_short(0) try: - await self._write_frame(frame, encoder) + await self._write_frame(0, request) except anyio.exceptions.ClosedResourceError: pass except Exception: @@ -351,14 +346,11 @@ async def __aenter__(self): 'consumer_cancel_notify': True, 'connection.blocked': False, }, - 'copyright': 'BSD', - 'product': _version.__package__, - 'product_version': _version.__version__, } client_properties.update(self.client_properties) # waiting reply start with credentions and co - await self.start_ok(client_properties, 'AMQPLAIN', self._auth, self.server_locales[0]) + await self.start_ok(client_properties, 'AMQPLAIN', self._auth, self.server_locales) # wait for a "tune" reponse await self.dispatch_frame() @@ -411,9 +403,8 @@ async def get_frame(self): """Read the frame, and only decode its header """ - frame = amqp_frame.AmqpResponse(self._stream) try: - await frame.read_frame() + channel, frame = await amqp_frame.read(self._stream) except ConnectionResetError: raise exceptions.AmqpClosedConnection(self) from None except EnvironmentError as err: @@ -423,46 +414,36 @@ async def get_frame(self): except anyio.exceptions.ClosedResourceError: raise exceptions.AmqpClosedConnection(self) from None - return frame + return channel, frame - async def dispatch_frame(self, frame=None): + async def dispatch_frame(self, frame_channel=None, frame=None): """Dispatch the received frame to the corresponding handler""" method_dispatch = { - (amqp_constants.CLASS_CONNECTION, - amqp_constants.CONNECTION_CLOSE): # noqa: E131 - self.server_close, - (amqp_constants.CLASS_CONNECTION, - amqp_constants.CONNECTION_CLOSE_OK): # noqa: E131 - self.close_ok, - (amqp_constants.CLASS_CONNECTION, - amqp_constants.CONNECTION_TUNE): # noqa: E131 - self.tune, - (amqp_constants.CLASS_CONNECTION, - amqp_constants.CONNECTION_START): # noqa: E131 - self.start, - (amqp_constants.CLASS_CONNECTION, - amqp_constants.CONNECTION_OPEN_OK): # noqa: E131 - self.open_ok, + pamqp.specification.Connection.Close.name: self.server_close, + pamqp.specification.Connection.CloseOk.name: self.close_ok, + pamqp.specification.Connection.Tune.name: self.tune, + pamqp.specification.Connection.Start.name: self.start, + pamqp.specification.Connection.OpenOk.name: self.open_ok, } if not frame: - frame = await self.get_frame() + frame_channel, frame = await self.get_frame() - if frame.frame_type == amqp_constants.TYPE_HEARTBEAT: + if isinstance(frame, pamqp.heartbeat.Heartbeat): return - if frame.channel is not 0: - channel = self.channels.get(frame.channel) + if frame_channel: + channel = self.channels.get(frame_channel) if channel is not None: await channel.dispatch_frame(frame) else: - logger.info("Unknown channel %s", frame.channel) + logger.info("Unknown channel %s", frame_channel) return - if (frame.class_id, frame.method_id) not in method_dispatch: - logger.info("frame %s %s is not handled", frame.class_id, frame.method_id) + if frame.name not in method_dispatch: + logger.info("frame %s is not handled", frame.name) return - await method_dispatch[(frame.class_id, frame.method_id)](frame) + await method_dispatch[frame.name](frame) def release_channel_id(self, channel_id): """Called from the channel instance, it relase a previously used @@ -545,39 +526,37 @@ async def send_heartbeat(self): It can be an ack for the server or the client willing to check for the connexion timeout """ - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_HEARTBEAT, 0) - request = amqp_frame.AmqpEncoder() - await self._write_frame(frame, request) + request = pamqp.heartbeat.Heartbeat() + await self._write_frame(0, request) # Amqp specific methods async def start(self, frame): """Method sent from the server to begin a new connection""" - response = amqp_frame.AmqpDecoder(frame.payload) - - self.version_major = response.read_octet() - self.version_minor = response.read_octet() - self.server_properties = response.read_table() - self.server_mechanisms = response.read_longstr().split(' ') - self.server_locales = response.read_longstr().split(' ') + self.version_major = frame.version_major + self.version_minor = frame.version_minor + self.server_properties = frame.server_properties + self.server_mechanisms = frame.mechanisms + self.server_locales = frame.locales async def start_ok(self, client_properties, mechanism, auth, locale): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, 0) - frame.declare_method(amqp_constants.CLASS_CONNECTION, amqp_constants.CONNECTION_START_OK) - request = amqp_frame.AmqpEncoder() - request.write_table(client_properties) - request.write_shortstr(mechanism) - request.write_table(auth) - request.write_shortstr(locale.encode()) - await self._write_frame(frame, request) + def credentials(): + return '\0{LOGIN}\0{PASSWORD}'.format(**auth) + + request = pamqp.specification.Connection.StartOk( + client_properties=client_properties, + mechanism=mechanism, + locale=locale, + response=credentials() + ) + await self._write_frame(0, request) async def server_close(self, frame): """The server is closing the connection""" self.state = CLOSING - response = amqp_frame.AmqpDecoder(frame.payload) - reply_code = response.read_short() - reply_text = response.read_shortstr() - class_id = response.read_short() - method_id = response.read_short() + reply_code = frame.reply_code + reply_text = frame.reply_text + class_id = frame.class_id + method_id = frame.method_id self._close_reason = dict(text=reply_text, code=reply_code, class_id=class_id, method_id=method_id) logger.warning( "Server closed connection: %s, code=%s, class_id=%s, method_id=%s", reply_text, @@ -587,10 +566,8 @@ async def server_close(self, frame): await self._close_ok() async def _close_ok(self): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, 0) - frame.declare_method(amqp_constants.CLASS_CONNECTION, amqp_constants.CONNECTION_CLOSE_OK) - request = amqp_frame.AmqpEncoder() - await self._write_frame(frame, request) + request = pamqp.specification.Connection.CloseOk() + await self._write_frame(0, request) await anyio.sleep(0) # give the write task one shot to send the frame if self._nursery is not None: await self._cancel_all() @@ -604,34 +581,25 @@ async def _cancel_all(self): await self._nursery.cancel_scope.cancel() async def tune(self, frame): - decoder = amqp_frame.AmqpDecoder(frame.payload) - self.server_channel_max = decoder.read_short() - self.server_frame_max = decoder.read_long() - self.server_heartbeat = decoder.read_short() + self.server_channel_max = frame.channel_max + self.server_frame_max = frame.frame_max + self.server_heartbeat = frame.heartbeat async def tune_ok(self, channel_max, frame_max, heartbeat): - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, 0) - frame.declare_method(amqp_constants.CLASS_CONNECTION, amqp_constants.CONNECTION_TUNE_OK) - encoder = amqp_frame.AmqpEncoder() - encoder.write_short(channel_max) - encoder.write_long(frame_max) - encoder.write_short(heartbeat) - - await self._write_frame(frame, encoder) + request = pamqp.specification.Connection.TuneOk( + channel_max, frame_max, heartbeat + ) + await self._write_frame(0, request) async def secure_ok(self, login_response): pass async def open(self, virtual_host, capabilities='', insist=False): """Open connection to virtual host.""" - frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, 0) - frame.declare_method(amqp_constants.CLASS_CONNECTION, amqp_constants.CONNECTION_OPEN) - encoder = amqp_frame.AmqpEncoder() - encoder.write_shortstr(virtual_host) - encoder.write_shortstr(capabilities) - encoder.write_bool(insist) - - await self._write_frame(frame, encoder) + request = pamqp.specification.Connection.Open( + virtual_host, capabilities, insist + ) + await self._write_frame(0, request) async def open_ok(self, frame): self.state = OPEN diff --git a/ci/requirements_dev.txt b/ci/requirements_dev.txt index d5c1c06..20f8a03 100644 --- a/ci/requirements_dev.txt +++ b/ci/requirements_dev.txt @@ -6,4 +6,7 @@ coverage pylint pytest pytest-trio --e git+https://github.com/bkjones/pyrabbit.git#egg=pyrabbit +Sphinx +sphinx-rtd-theme + +pyrabbit2 diff --git a/debian/control b/debian/control index a028fe2..e1b4714 100644 --- a/debian/control +++ b/debian/control @@ -8,7 +8,7 @@ Build-Depends: python3-setuptools, python3-all, debhelper (>= 9), python3-pytest-trio, python3-setuptools-scm, python3-requests (>> 2.21), - python3-pyrabbit, + python3-pyrabbit2, Standards-Version: 3.9.1 Package: python3-asyncamqp @@ -16,7 +16,7 @@ Architecture: all Depends: ${misc:Depends}, ${python3:Depends}, python3 (>> 3.6), python3-anyio, - python3-pyrabbit, + python3-pyrabbit2, Description: AMQP implementation using anyio The asyncamqp library is a pure-Python implementation of the AMQP 0.9.1 protocol. diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..f9d4994 --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,17 @@ +version: '3' +services: + aioamqp-test: + build: . + command: ["make", "test"] + depends_on: + - rabbitmq + environment: + - AMQP_HOST=rabbitmq + rabbitmq: + hostname: rabbitmq + image: rabbitmq:3-management + environment: + - RABBITMQ_NODENAME=my-rabbit + ports: + - 15672 + - 5672 diff --git a/docs/api.rst b/docs/api.rst index ec390ae..0fb610d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -95,8 +95,8 @@ Returned messages will be delivered to your code in an async iterator over the c await taskgroup.spawn(do_returns, chan) do_whatever() -The code above ensures that the iterator is started before calling ``do_whatever()``. -Returned messages arriving before that will be logged and discarded. +The code above ensures that the iterator is started before calling ``do_whatever()``, +ensuring that returned messages will be processed properly. Consuming messages ------------------ @@ -130,7 +130,7 @@ from the queue:: expiration message_id timestamp - type + message_type user_id app_id cluster_id @@ -140,6 +140,19 @@ Remember that you need to call either ``basic_ack(delivery_tag)`` or server will not know that you processed it, and thus will not send more messages. +Server Cancellation +~~~~~~~~~~~~~~~~~~~ + +RabbitMQ offers an AMQP extension to notify a consumer when a queue is deleted. +See `Consumer Cancel Notification `_ +for additional details. ``asyncamqp`` enables the extension for all channels +and terminates the channel's receiver loop when the consumer is cancelled:: + + async with chan.new_consumer(queue_name="my_queue") as listener: + async for body, envelope, properties in listener: + process_message(body, envelope, properties) + print("I get here when the queue is deleted") + Queues ------ @@ -277,4 +290,3 @@ Note: the `internal` flag is deprecated and not used in this library. :param str routing_key: the key used to filter messages :param bool no_wait: if set, the server will not respond to the method :param dict arguments: AMQP arguments to be passed when removing the exchange. - diff --git a/docs/changelog.rst b/docs/changelog.rst index ad861d6..0b43eda 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,9 +1,19 @@ Changelog ========= -Trio-amqp +Asyncamqp +++++++++ +Asyncamqp 0.3 +------------- + + * Merge to Aioamqp 0.14. + +Asyncamqp 0.2 +------------- + + * Rewrote the whole package to use AnyIO instead. + Trio-amqp 0.1 ------------- @@ -11,9 +21,32 @@ Trio-amqp 0.1 * Changed AmqpProtocol to be an async context manager while the connection is active. + Next release ------------ +Aioamqp 0.14.0 +-------------- + + * Fix ``waiter already exist`` issue when creating multiple queues (closes #105). + * Rename ``type`` to ``message_type`` in constant.Properties object to be full compatible with pamqp. + * Add python 3.8 support. + +Aioamqp 0.13.0 +-------------- + + * SSL Connections must be configured with an SSLContext object in ``connect`` and ``from_url`` (closes #142). + * Uses pamqp to encode or decode protocol frames. + * Drops support of python 3.3 and python 3.4. + * Uses async and await keywords. + * Fix pamqp `_frame_parts` call, now uses exposed `frame_parts` + +Aioamqp 0.12.0 +-------------- + + * Fix an issue to use correct int encoder depending on int size (closes #180). + * Call user-specified callback when a consumer is cancelled. + Aioamqp 0.11.0 -------------- diff --git a/docs/examples/hello_world.rst b/docs/examples/hello_world.rst index b1658d5..c2a9265 100644 --- a/docs/examples/hello_world.rst +++ b/docs/examples/hello_world.rst @@ -70,4 +70,3 @@ To consume a message, the library calls a callback (which **MUST** be a coroutin print(body) await channel.basic_consume(callback, queue_name='hello', no_ack=True) - diff --git a/docs/examples/rpc.rst b/docs/examples/rpc.rst index 040b816..3c04e80 100644 --- a/docs/examples/rpc.rst +++ b/docs/examples/rpc.rst @@ -65,4 +65,3 @@ response from this request. ) await channel.basic_client_ack(delivery_tag=envelope.delivery_tag) - diff --git a/docs/examples/work_queue.rst b/docs/examples/work_queue.rst index 90096b7..4ce2fbd 100644 --- a/docs/examples/work_queue.rst +++ b/docs/examples/work_queue.rst @@ -56,4 +56,3 @@ takes time. await anyio.sleep(body.count(b'.')) print(" [x] Done") await channel.basic_client_ack(delivery_tag=envelope.delivery_tag) - diff --git a/examples/emit_log.py b/examples/emit_log.py index 5351248..d22fb2a 100755 --- a/examples/emit_log.py +++ b/examples/emit_log.py @@ -23,7 +23,6 @@ async def exchange_routing(): await channel.exchange_declare( exchange_name=exchange_name, type_name='fanout' ) - await channel.basic_publish( message, exchange_name=exchange_name, routing_key='' ) diff --git a/examples/emit_log_direct.py b/examples/emit_log_direct.py index 8d32277..4afd46d 100755 --- a/examples/emit_log_direct.py +++ b/examples/emit_log_direct.py @@ -26,10 +26,11 @@ async def exchange_routing(): message, exchange_name=exchange_name, routing_key=severity ) print(" [x] Sent %r" % (message,)) + await protocol.close() + transport.close() except asyncamqp.AmqpClosedConnection: print("closed connections") return - anyio.run(exchange_routing) diff --git a/examples/receive_log_direct.py b/examples/receive_log_direct.py index b1b9527..0ce0ce6 100644 --- a/examples/receive_log_direct.py +++ b/examples/receive_log_direct.py @@ -13,12 +13,7 @@ async def callback(channel, body, envelope, properties): - print( - "consumer {} recved {} ({})".format( - envelope.consumer_tag, body, envelope.delivery_tag - ) - ) - + print("consumer {} recved {} ({})".format(envelope.consumer_tag, body, envelope.delivery_tag)) async def receive_log(): try: diff --git a/examples/send_with_return.py b/examples/send_with_return.py index de5fcd1..fdb9ab4 100644 --- a/examples/send_with_return.py +++ b/examples/send_with_return.py @@ -19,8 +19,7 @@ async def handle_return(channel, body, envelope, properties): envelope.reply_text, envelope.exchange_name)) - async def get_returns(chan): - task_status.started() +async def get_returns(chan): # DO NOT await() between these statements async for body, envelope, properties in chan: await handle_return(channel, body, envelope, properties) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..0ab7d0b --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[bdist_wheel] +python-tag = py35.py36.py37.py38 diff --git a/setup.py b/setup.py index 98b9479..9a97831 100644 --- a/setup.py +++ b/setup.py @@ -10,16 +10,20 @@ url='https://github.com/python-trio/asyncamqp', description=description, long_description=open('README.rst').read(), - download_url='https://pypi.python.org/pypi/asyncamqp', + # download_url='https://pypi.python.org/pypi/asyncamqp', setup_requires=[ 'pyrabbit', ], install_requires=[ 'anyio', ], + keywords=['asyncio', 'amqp', 'rabbitmq', 'aio'], packages=[ 'asyncamqp', ], + install_requires=[ + 'pamqp>=2.2.0,<3', + ], classifiers=[ "Development Status :: 4 - Beta", "Intended Audience :: Developers", @@ -30,7 +34,10 @@ "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", "Framework :: Trio", + "Framework :: Anyio", ], platforms='all', license='BSD' diff --git a/tests/test_connect.py b/tests/test_connect.py index c5f38a0..7a487b6 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -22,6 +22,8 @@ async def test_connect_tuning(self): channel_max = 10 heartbeat = 100 proto = testcase.connect( + host=self.host, + port=self.port, virtualhost=self.vhost, channel_max=channel_max, frame_max=frame_max, @@ -45,8 +47,8 @@ async def test_connect_tuning(self): @pytest.mark.trio async def test_socket_nodelay(self): self.reset_vhost() - proto = testcase.connect(virtualhost=self.vhost) + proto = testcase.connect(host=self.host, port=self.port, virtualhost=self.vhost) async with proto as amqp: sock = amqp._stream opt_val = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - assert opt_val == 1, opt_val + assert opt_val > 0 diff --git a/tests/test_frame.py b/tests/test_frame.py deleted file mode 100644 index 1020e76..0000000 --- a/tests/test_frame.py +++ /dev/null @@ -1,134 +0,0 @@ -""" - Test frame format. -""" - -import io -import pytest -import sys -import datetime - -from decimal import Decimal - -from asyncamqp import constants as amqp_constants -from asyncamqp import frame as frame_module -from asyncamqp.frame import AmqpEncoder -from asyncamqp.frame import AmqpResponse - - -class TestEncoder: - """Test encoding of python builtin objects to AMQP frames.""" - - _multiprocess_can_split_ = True - - def setup(self): - self.encoder = AmqpEncoder() - - def test_write_string(self): - self.encoder.write_value("foo") - assert self.encoder.payload.getvalue() == \ - b'S\x00\x00\x00\x03foo' - # 'S' + size (4 bytes) + payload - - def test_write_bool(self): - self.encoder.write_value(True) - assert self.encoder.payload.getvalue() == b't\x01' - - def test_write_array(self): - self.encoder.write_value(["v1", 123]) - assert self.encoder.payload.getvalue() == \ - b'A\x00\x00\x00\x0cS\x00\x00\x00\x02v1I\x00\x00\x00{' - # total size (4 bytes) + 'S' + size (4 bytes) + payload + 'I' + - # size (4 bytes) + payload - - def test_write_float(self): - self.encoder.write_value(1.1) - assert self.encoder.payload.getvalue() == b'd?\xf1\x99\x99\x99\x99\x99\x9a' - - def test_write_decimal(self): - self.encoder.write_value(Decimal("-1.1")) - assert self.encoder.payload.getvalue() == b'D\x01\xff\xff\xff\xf5' - - self.encoder.write_value(Decimal("1.1")) - assert self.encoder.payload.getvalue() == b'D\x01\xff\xff\xff\xf5D\x01\x00\x00\x00\x0b' - - def test_write_datetime(self): - self.encoder.write_value(datetime.datetime(2017, 12, 10, 4, 6, 49, 548918)) - assert self.encoder.payload.getvalue() == b'T\x00\x00\x00\x00Z,\xb2\xd9' - - def test_write_dict(self): - self.encoder.write_value({'foo': 'bar', 'bar': 'baz'}) - assert self.encoder.payload.getvalue() in \ - (b'F\x00\x00\x00\x18\x03barS\x00\x00\x00\x03baz\x03fooS\x00\x00\x00\x03bar', # noqa: E501 - b'F\x00\x00\x00\x18\x03fooS\x00\x00\x00\x03bar\x03barS\x00\x00\x00\x03baz') # noqa: E501 - # 'F' + total size + key (always a string) + value (with type) + ... - # The keys are not ordered, so the output is not deterministic - # (two possible values) - - def test_write_none(self): - self.encoder.write_value(None) - assert self.encoder.payload.getvalue() == b'V' - - def test_write_message_properties_dont_crash(self): - properties = { - 'content_type': 'plain/text', - 'content_encoding': 'utf8', - 'headers': { - 'key': 'value' - }, - 'delivery_mode': 2, - 'priority': 10, - 'correlation_id': '122', - 'reply_to': 'joe', - 'expiration': 'someday', - 'message_id': 'm_id', - 'timestamp': 12345, - 'type': 'a_type', - 'user_id': 'joe_42', - 'app_id': 'roxxor_app', - 'cluster_id': 'a_cluster', - } - self.encoder.write_message_properties(properties) - assert len(self.encoder.payload.getvalue()) != 0 - - def test_write_message_correlation_id_encode(self): - properties = { - 'delivery_mode': 2, - 'priority': 0, - 'correlation_id': '122', - } - self.encoder.write_message_properties(properties) - assert self.encoder.payload.getvalue() == b'\x1c\x00\x02\x00\x03122' - - def test_write_message_priority_zero(self): - properties = { - 'delivery_mode': 2, - 'priority': 0, - } - self.encoder.write_message_properties(properties) - assert self.encoder.payload.getvalue() == b'\x18\x00\x02\x00' - - def test_write_message_properties_raises_on_invalid_property_name(self): - properties = { - 'invalid': 'coucou', - } - with pytest.raises(ValueError): - self.encoder.write_message_properties(properties) - - -class TestAmqpResponse: - def test_dump_dont_crash(self): - frame = AmqpResponse(None) - frame.frame_type = amqp_constants.TYPE_METHOD - frame.class_id = 0 - frame.method_id = 0 - saved_stout = sys.stdout - frame_module.DUMP_FRAMES = True - sys.stdout = io.StringIO() - try: - last_len = len(sys.stdout.getvalue()) - print(self) - # assert something has been writen - assert len(sys.stdout.getvalue()) > last_len - finally: - frame_module.DUMP_FRAMES = False - sys.stdout = saved_stout diff --git a/tests/test_protocol.py b/tests/test_protocol.py index b95cbf0..b4d13cd 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -15,7 +15,7 @@ class TestProtocol(testcase.RabbitTestCase): @pytest.mark.trio async def test_connect(self): self.reset_vhost() - amqp = testcase.connect(virtualhost=self.vhost) + amqp = testcase.connect(host=self.host, port=self.port, virtualhost=self.vhost) async with amqp as protocol: assert protocol.state == OPEN @@ -27,7 +27,7 @@ async def test_connect_products_info(self): 'program_version': '0.1.1', } amqp = testcase.connect( - virtualhost=self.vhost, + host=self.host, port=self.port, virtualhost=self.vhost, client_properties=client_properties, ) async with amqp as protocol: @@ -37,7 +37,7 @@ async def test_connect_products_info(self): async def test_connection_unexistant_vhost(self): self.reset_vhost() with pytest.raises(exceptions.AmqpClosedConnection): - amqp = testcase.connect(virtualhost='/unexistant') + amqp = testcase.connect(host=self.host, port=self.port, virtualhost='/unexistant') async with amqp: pass diff --git a/tests/test_publish.py b/tests/test_publish.py index 0dac747..6b1b877 100644 --- a/tests/test_publish.py +++ b/tests/test_publish.py @@ -20,6 +20,21 @@ async def test_publish(self, channel): await self.check_messages(channel.protocol, "q", 1) + @pytest.mark.trio + async def test_empty_publish(self): + # declare + await self.channel.queue_declare("q", exclusive=True, no_wait=False) + await self.channel.exchange_declare("e", "fanout") + await self.channel.queue_bind("q", "e", routing_key='') + + # publish + await self.channel.publish("", "e", routing_key='') + + queues = self.list_queues() + assert "q" in queues + assert queues["q"]["messages"] == 1 + assert queues["q"]["message_bytes"] == 0 + @pytest.mark.trio async def test_big_publish(self, channel): # declare diff --git a/tests/test_queue.py b/tests/test_queue.py index c24958a..db9bfbb 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -47,6 +47,20 @@ async def test_queue_declare_passive(self, channel): assert result['consumer_count'] == 0, result assert channel.protocol.local_name(result['queue']) == queue_name, result + @pytest.mark.trio + async def test_queue_declare_custom_x_message_ttl_32_bits(self): + queue_name = 'queue_name' + # 2147483648 == 10000000000000000000000000000000 + # in binary, meaning it is 32 bit long + x_message_ttl = 2147483648 + result = await self.channel.queue_declare('queue_name', arguments={ + 'x-message-ttl': x_message_ttl + }) + assert result['message_count'] == 0 + assert result['consumer_count'] == 0 + assert result['queue'].split('.')[-1] == queue_name + assert result + @pytest.mark.trio async def test_queue_declare_passive_nonexistant_queue(self, channel): queue_name = 'q17' diff --git a/tests/testcase.py b/tests/testcase.py index 6705f70..8e215aa 100644 --- a/tests/testcase.py +++ b/tests/testcase.py @@ -13,7 +13,7 @@ from functools import wraps from async_generator import asynccontextmanager -import pyrabbit.api +import pyrabbit2 as pyrabbit from . import testcase from asyncamqp import exceptions, connect_amqp @@ -109,7 +109,7 @@ def reset_vhost(): port = int(os.environ.get('AMQP_PORT', 5672)) vhost = os.environ.get('AMQP_VHOST', 'test' + str(uuid.uuid4())) http_client = pyrabbit.api.Client( - '%s:%s/api/' % (host, 10000 + port), 'guest', 'guest', timeout=20 + '%s:%s' % (host, 10000 + port), 'guest', 'guest', timeout=20 ) try: http_client.create_vhost(vhost) @@ -305,7 +305,7 @@ async def safe_queue_delete(self, queue_name, channel): except TimeoutError: logger.warning('Timeout on queue %s deletion', full_queue_name, exc_info=True) except Exception: # pylint: disable=broad-except - logger.error('Unexpected error on queue %s deletion', full_queue_name, exc_info=True) + logger.exception('Unexpected error on queue %s deletion', full_queue_name) async def safe_exchange_delete(self, exchange_name, channel=None): """Delete the exchange but does not raise any exception if it fails @@ -319,9 +319,7 @@ async def safe_exchange_delete(self, exchange_name, channel=None): except TimeoutError: logger.warning('Timeout on exchange %s deletion', full_exchange_name, exc_info=True) except Exception: # pylint: disable=broad-except - logger.error( - 'Unexpected error on exchange %s deletion', full_exchange_name, exc_info=True - ) + logger.exception('Unexpected error on exchange %s deletion', full_exchange_name) async def queue_declare(self, queue_name, *args, channel=None, safe_delete_before=True, **kw): channel = channel or self.channel diff --git a/tox.ini b/tox.ini index ef3712d..1950753 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py35, py36 +envlist = py36, py37, py38 skipsdist = true skip_missing_interpreters = true