diff --git a/src/workflows/services/sample_consumer.py b/src/workflows/services/sample_consumer.py index 338bcb27..9dd7183b 100644 --- a/src/workflows/services/sample_consumer.py +++ b/src/workflows/services/sample_consumer.py @@ -24,13 +24,13 @@ def consume_message(self, header, message): t = (time.time() % 1000) * 1000 if header: - header = json.dumps(header, indent=2) + "\n" + "----------------" + "\n" + header_str = json.dumps(header, indent=2) + "\n" + "----------------" + "\n" else: - header = "" + header_str = "" if isinstance(message, dict): message = json.dumps(message, indent=2) + "\n" + "----------------" + "\n" self.log.info( - f"=== Consume ====\n{header}{message}\nReceived message @{t:10.3f} ms" + f"=== Consume ====\n{header_str}{message}\nReceived message @{t:10.3f} ms" ) time.sleep(0.1) diff --git a/src/workflows/transport/common_transport.py b/src/workflows/transport/common_transport.py index cc707e99..d10012f8 100644 --- a/src/workflows/transport/common_transport.py +++ b/src/workflows/transport/common_transport.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import decimal import logging -from typing import Any, Callable, Dict, Mapping, Set +from typing import Any, Callable, Dict, Mapping, Optional, Set import workflows @@ -12,10 +14,10 @@ class CommonTransport: subscriptions and transactions.""" __callback_interceptor = None - __subscriptions: Dict[Any, Any] = {} - __subscription_id = 0 - __transactions: Set[Any] = set() - __transaction_id = 0 + __subscriptions: Dict[int, Dict[str, Any]] = {} + __subscription_id: int = 0 + __transactions: Set[int] = set() + __transaction_id: int = 0 log = logging.getLogger("workflows.transport") @@ -44,7 +46,7 @@ def disconnect(self): """Gracefully disconnect the transport class. This function should be overridden.""" - def subscribe(self, channel, callback, **kwargs): + def subscribe(self, channel, callback, **kwargs) -> int: """Listen to a queue, notify via callback function. :param channel: Queue name to subscribe to :param callback: Function to be called when messages are received. @@ -76,7 +78,7 @@ def mangled_callback(header, message): self._subscribe(self.__subscription_id, channel, mangled_callback, **kwargs) return self.__subscription_id - def unsubscribe(self, subscription, drop_callback_reference=False, **kwargs): + def unsubscribe(self, subscription: int, drop_callback_reference=False, **kwargs): """Stop listening to a queue or a broadcast :param subscription: Subscription ID to cancel :param drop_callback_reference: Drop the reference to the registered @@ -98,7 +100,7 @@ def unsubscribe(self, subscription, drop_callback_reference=False, **kwargs): if drop_callback_reference: self.drop_callback_reference(subscription) - def drop_callback_reference(self, subscription): + def drop_callback_reference(self, subscription: int): """Drop reference to the callback function after unsubscribing. Any future messages arriving for that subscription will result in exceptions being raised. @@ -114,7 +116,7 @@ def drop_callback_reference(self, subscription): ) del self.__subscriptions[subscription] - def subscribe_broadcast(self, channel, callback, **kwargs): + def subscribe_broadcast(self, channel, callback, **kwargs) -> int: """Listen to a broadcast topic, notify via callback function. :param channel: Topic name to subscribe to :param callback: Function to be called when messages are received. @@ -150,7 +152,7 @@ def mangled_callback(header, message): ) return self.__subscription_id - def subscription_callback(self, subscription) -> MessageCallback: + def subscription_callback(self, subscription: int) -> MessageCallback: """Retrieve the callback function for a subscription. Raise a workflows.Error if the subscription does not exist. All transport callbacks can be intercepted by setting an @@ -232,7 +234,7 @@ def raw_broadcast(self, destination, message, **kwargs): """ self._broadcast(destination, message, **kwargs) - def ack(self, message, subscription_id=None, **kwargs): + def ack(self, message, subscription_id: Optional[int] = None, **kwargs): """Acknowledge receipt of a message. This only makes sense when the 'acknowledgement' flag was set for the relevant subscription. :param message: ID of the message to be acknowledged, OR a dictionary @@ -259,7 +261,7 @@ def ack(self, message, subscription_id=None, **kwargs): ) self._ack(message_id, subscription_id=subscription_id, **kwargs) - def nack(self, message, subscription_id=None, **kwargs): + def nack(self, message, subscription_id: Optional[int] = None, **kwargs): """Reject receipt of a message. This only makes sense when the 'acknowledgement' flag was set for the relevant subscription. :param message: ID of the message to be rejected, OR a dictionary @@ -282,11 +284,11 @@ def nack(self, message, subscription_id=None, **kwargs): if not subscription_id: raise workflows.Error("Cannot reject message without subscription ID") self.log.debug( - "Rejecting message %s on subscription %s", message_id, subscription_id + "Rejecting message %s on subscription %d", message_id, subscription_id ) self._nack(message_id, subscription_id=subscription_id, **kwargs) - def transaction_begin(self, **kwargs): + def transaction_begin(self, **kwargs) -> int: """Start a new transaction. :param **kwargs: Further parameters for the transport layer. For example :return: A transaction ID that can be passed to other functions. @@ -297,7 +299,7 @@ def transaction_begin(self, **kwargs): self._transaction_begin(self.__transaction_id, **kwargs) return self.__transaction_id - def transaction_abort(self, transaction_id, **kwargs): + def transaction_abort(self, transaction_id: int, **kwargs): """Abort a transaction and roll back all operations. :param transaction_id: ID of transaction to be aborted. :param **kwargs: Further parameters for the transport layer. @@ -308,7 +310,7 @@ def transaction_abort(self, transaction_id, **kwargs): self.__transactions.remove(transaction_id) self._transaction_abort(transaction_id, **kwargs) - def transaction_commit(self, transaction_id, **kwargs): + def transaction_commit(self, transaction_id: int, **kwargs): """Commit a transaction. :param transaction_id: ID of transaction to be committed. :param **kwargs: Further parameters for the transport layer. @@ -341,7 +343,7 @@ def _subscribe(self, sub_id: int, channel, callback, **kwargs): """ raise NotImplementedError("Transport interface not implemented") - def _subscribe_broadcast(self, sub_id, channel, callback, **kwargs): + def _subscribe_broadcast(self, sub_id: int, channel, callback, **kwargs): """Listen to a broadcast topic, notify via callback function. :param sub_id: ID for this subscription in the transport layer :param channel: Topic name to subscribe to @@ -351,7 +353,7 @@ def _subscribe_broadcast(self, sub_id, channel, callback, **kwargs): """ raise NotImplementedError("Transport interface not implemented") - def _unsubscribe(self, sub_id): + def _unsubscribe(self, sub_id: int, **kwargs): """Stop listening to a queue or a broadcast :param sub_id: ID for this subscription in the transport layer """ diff --git a/src/workflows/transport/pika_transport.py b/src/workflows/transport/pika_transport.py index bd853c59..e310f67e 100644 --- a/src/workflows/transport/pika_transport.py +++ b/src/workflows/transport/pika_transport.py @@ -9,8 +9,6 @@ import sys import threading import time -import uuid -from collections.abc import Hashable from concurrent.futures import Future from enum import Enum, auto from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -319,14 +317,14 @@ def broadcast_status(self, status): def _call_message_callback( self, - sub_id: str, + subscription_id: int, _channel: pika.channel.Channel, method: pika.spec.Basic.Deliver, properties: pika.spec.BasicProperties, body: bytes, ): """Rewrite and redirect a pika callback to the subscription function""" - self.subscription_callback(sub_id)( + self.subscription_callback(subscription_id)( { "consumer_tag": str(method.consumer_tag), "delivery_mode": properties.delivery_mode, @@ -335,6 +333,7 @@ def _call_message_callback( "message-id": method.delivery_tag, "redelivered": method.redelivered, "routing_key": method.routing_key, + "subscription": subscription_id, }, body, ) @@ -391,7 +390,7 @@ def _subscribe( callback=functools.partial(self._call_message_callback, sub_id), auto_ack=not acknowledgement, exclusive=exclusive, - consumer_tag=str(sub_id), + subscription_id=sub_id, reconnectable=reconnectable, prefetch_count=prefetch_count, ).result() @@ -403,7 +402,7 @@ def _subscribe( def _subscribe_broadcast( self, - sub_id: Hashable, + sub_id: int, channel: str, callback: MessageCallback, *, @@ -429,15 +428,15 @@ def _subscribe_broadcast( self._pika_thread.subscribe_broadcast( exchange=channel, callback=functools.partial(self._call_message_callback, sub_id), - consumer_tag=str(sub_id), + subscription_id=sub_id, reconnectable=reconnectable, ).result() - def _unsubscribe(self, consumer_tag): + def _unsubscribe(self, sub_id: int, **kwargs): """Stop listening to a queue - :param consumer_tag: Consumer Tag to cancel + :param sub_id: Consumer Tag to cancel """ - self._pika_thread.unsubscribe(str(consumer_tag)) + self._pika_thread.unsubscribe(sub_id) # self._channel.basic_cancel(consumer_tag=consumer_tag, callback=None) # Callback reference is kept as further messages may already have been received @@ -539,7 +538,7 @@ def _transaction_commit(self, **kwargs): # self._channel.tx_commit() def _ack( - self, message_id, subscription_id: str, *, multiple: bool = False, **_kwargs + self, message_id, subscription_id: int, *, multiple: bool = False, **_kwargs ): """ Acknowledge receipt of a message. @@ -560,7 +559,7 @@ def _ack( def _nack( self, message_id, - subscription_id: str, + subscription_id: int, *, multiple: bool = False, requeue: bool = True, @@ -694,12 +693,14 @@ def __init__( name="workflows pika_transport", daemon=False, target=self._run ) self._state: _PikaThreadStatus = _PikaThreadStatus.NEW - # Internal store of subscriptions, to resubscribe if necessary - self._subscriptions: Dict[str, _PikaSubscription] = {} + # Internal store of subscriptions, to resubscribe if necessary. Keys are + # unique and auto-generated, and known as subscription IDs or consumer tags + # (strictly: pika/AMQP consumer tags are strings, not integers) + self._subscriptions: Dict[int, _PikaSubscription] = {} # The pika connection object self._connection: Optional[pika.BlockingConnection] = None # Per-subscription channels. May be pointing to the shared channel - self._pika_channels: Dict[str, BlockingChannel] = {} + self._pika_channels: Dict[int, BlockingChannel] = {} # A common, shared channel, used for non-QoS subscriptions self._pika_shared_channel: Optional[BlockingChannel] # Are we allowed to reconnect. Can only be turned off, never on @@ -798,9 +799,9 @@ def subscribe_queue( self, queue: str, callback: PikaCallback, + subscription_id: int, *, auto_ack: bool = True, - consumer_tag: Optional[str] = None, exclusive: bool = False, prefetch_count: int = 1, reconnectable: bool = False, @@ -809,9 +810,9 @@ def subscribe_queue( Subscribe to a queue. Thread-safe. Args: - consumer_tag: Internal ID representing this subscription queue: The queue to listen for messages on callback: The function to call when receiving messages on this queue + subscription_id: Internal ID representing this subscription. auto_ack: Should this subscription auto-acknowledge messages? exclusive: Should we be the only consumer? prefetch_count: How many messages are we allowed to prefetch @@ -825,10 +826,6 @@ def subscribe_queue( if not self._connection: raise RuntimeError("Cannot subscribe to unstarted connection") - # Safety: Since our Ack interface doesn't ask consumer ID yet, we can't ack - if not auto_ack and prefetch_count != 0: - raise ValueError("Cannot turn on manual acknowledgements with prefetch > 0") - new_sub = _PikaSubscription( arguments={}, auto_ack=auto_ack, @@ -842,7 +839,7 @@ def subscribe_queue( result: Future[None] = Future() self._connection.add_callback_threadsafe( functools.partial( - self._add_subscription_in_thread, consumer_tag, new_sub, result + self._add_subscription_in_thread, subscription_id, new_sub, result ) ) return result @@ -851,9 +848,9 @@ def subscribe_broadcast( self, exchange: str, callback: PikaCallback, + subscription_id: int, *, auto_ack: bool = True, - consumer_tag: Optional[str] = None, reconnectable: bool = False, prefetch_count: int = 0, ) -> Future[None]: @@ -864,7 +861,7 @@ def subscribe_broadcast( exchange: The queue to listen for messages on callback: The function to call when receiving messages on this queue auto_ack: Should this subscription auto-acknowledge messages? - consumer_tag: Internal ID representing this subscription. Generated if unspecified. + subscription_id: Internal ID representing this subscription. prefetch_count: How many messages are we allowed to prefetch reconnectable: Are we allowed to reconnect to this subscription? @@ -891,14 +888,16 @@ def subscribe_broadcast( result: Future[None] = Future() self._connection.add_callback_threadsafe( functools.partial( - self._add_subscription_in_thread, consumer_tag, new_sub, result + self._add_subscription_in_thread, subscription_id, new_sub, result ) ) return result - def unsubscribe(self, consumer_tag: str) -> Future[None]: - if consumer_tag not in self._subscriptions: - raise KeyError(f"No such subscription with consumer tag '{consumer_tag}'") + def unsubscribe(self, subscription_id: int) -> Future[None]: + if subscription_id not in self._subscriptions: + raise KeyError( + f"No subscription with ID {subscription_id} to unsubscribe from" + ) assert self._connection is not None @@ -907,10 +906,10 @@ def unsubscribe(self, consumer_tag: str) -> Future[None]: def _unsubscribe(): try: if result.set_running_or_notify_cancel(): - logger.debug("Unsubscribing consumer tag '%s'", consumer_tag) - del self._subscriptions[consumer_tag] - channel = self._pika_channels.pop(consumer_tag) - channel.basic_cancel(str(consumer_tag)) + logger.debug("Unsubscribing from subscription %d", subscription_id) + del self._subscriptions[subscription_id] + channel = self._pika_channels.pop(subscription_id) + channel.basic_cancel(str(subscription_id)) # Close the channel if nobody else is using it if channel not in self._pika_channels.values(): @@ -958,8 +957,7 @@ def _send(): self._connection.add_callback_threadsafe(_send) return future - def ack(self, delivery_tag: int, subscription_id: str, *, multiple=False): - subscription_id = str(subscription_id) + def ack(self, delivery_tag: int, subscription_id: int, *, multiple=False): if subscription_id not in self._subscriptions: raise KeyError(f"Could not find subscription {subscription_id} to ACK") @@ -972,9 +970,8 @@ def ack(self, delivery_tag: int, subscription_id: str, *, multiple=False): ) def nack( - self, delivery_tag: int, subscription_id: str, *, multiple=False, requeue=True + self, delivery_tag: int, subscription_id: int, *, multiple=False, requeue=True ): - subscription_id = str(subscription_id) if subscription_id not in self._subscriptions: raise KeyError(f"Could not find subscription {subscription_id} to NACK") @@ -1039,8 +1036,8 @@ def _recreate_subscriptions(self): logger.debug("Setting up %d subscriptions", len(old_subscriptions)) try: - for consumer_id, subscription in old_subscriptions.items(): - self._add_subscription(consumer_id, subscription) + for subscription_id, subscription in old_subscriptions.items(): + self._add_subscription(subscription_id, subscription) except BaseException: # If something goes (temporarily) wrong recreating, then we # don't want to only partially resubscribe next time @@ -1051,15 +1048,15 @@ def _recreate_subscriptions(self): f"Subscriptions recreated. Reconnections allowed? - {'Yes' if self._reconnection_allowed else 'No.'}" ) - def _add_subscription(self, consumer_tag: str, subscription: _PikaSubscription): + def _add_subscription(self, subscription_id: int, subscription: _PikaSubscription): assert self._connection is not None - assert consumer_tag not in self._subscriptions + assert subscription_id not in self._subscriptions # We flip reconnection to False if any subscription is not reconnectable if self._reconnection_allowed and not subscription.reconnectable: self._reconnection_allowed = False logger.debug( - f"Subscription {consumer_tag} to '{subscription.destination}' is not reconnectable. Turning reconnection off." + f"Subscription {subscription_id} to '{subscription.destination}' is not reconnectable. Turning reconnection off." ) # Either open a channel (if prefetch) or use the shared one @@ -1086,12 +1083,12 @@ def _add_subscription(self, consumer_tag: str, subscription: _PikaSubscription): subscription.on_message_callback, auto_ack=subscription.auto_ack, exclusive=subscription.exclusive, - consumer_tag=consumer_tag, + consumer_tag=str(subscription_id), ) # Only now we have subscribed successfully, add to the list - self._pika_channels[consumer_tag] = channel - self._subscriptions[consumer_tag] = subscription - logger.debug("Consuming (%s) on %s", consumer_tag, subscription.queue) + self._pika_channels[subscription_id] = channel + self._subscriptions[subscription_id] = subscription + logger.debug("Consuming (%d) on %s", subscription_id, subscription.queue) def _run(self): if self._please_stop.is_set(): @@ -1214,7 +1211,7 @@ def _run(self): def _add_subscription_in_thread( self, - consumer_tag: Optional[str], + subscription_id: int, subscription: _PikaSubscription, result: Future, ): @@ -1225,14 +1222,11 @@ def _add_subscription_in_thread( """ try: if result.set_running_or_notify_cancel(): - # If not specified, generate a consumer_tag automatically - if consumer_tag is None: - consumer_tag = str(uuid.uuid4()) assert ( - consumer_tag not in self._subscriptions - ), f"Subscription request {consumer_tag} rejected due to existing subscription {self._subscriptions[consumer_tag]}" - self._add_subscription(consumer_tag, subscription) - result.set_result(self._subscriptions[consumer_tag].queue) + subscription_id not in self._subscriptions + ), f"Subscription request {subscription_id} rejected due to existing subscription {self._subscriptions[subscription_id]}" + self._add_subscription(subscription_id, subscription) + result.set_result(self._subscriptions[subscription_id].queue) except BaseException as e: result.set_exception(e) raise diff --git a/tests/transport/test_pika.py b/tests/transport/test_pika.py index 216eb33e..41016dac 100644 --- a/tests/transport/test_pika.py +++ b/tests/transport/test_pika.py @@ -596,7 +596,7 @@ def test_subscribe_to_queue(mock_pikathread): assert kwargs == { "auto_ack": True, "callback": mock.ANY, - "consumer_tag": "1", + "subscription_id": 1, "exclusive": False, "prefetch_count": 1, "queue": str(mock.sentinel.queue1), @@ -611,7 +611,7 @@ def test_subscribe_to_queue(mock_pikathread): assert kwargs == { "auto_ack": True, "callback": mock.ANY, - "consumer_tag": "2", + "subscription_id": 2, "exclusive": False, "prefetch_count": 1, "queue": str(mock.sentinel.queue2), @@ -625,7 +625,7 @@ def test_subscribe_to_queue(mock_pikathread): assert kwargs == { "auto_ack": False, "callback": mock.ANY, - "consumer_tag": "3", + "subscription_id": 3, "exclusive": False, "prefetch_count": 1, "queue": str(mock.sentinel.queue3), @@ -633,9 +633,9 @@ def test_subscribe_to_queue(mock_pikathread): } transport._unsubscribe(1) - mock_pikathread.unsubscribe.assert_called_once_with("1") + mock_pikathread.unsubscribe.assert_called_once_with(1) transport._unsubscribe(2) - mock_pikathread.unsubscribe.assert_called_with("2") + mock_pikathread.unsubscribe.assert_called_with(2) def test_subscribe_to_broadcast(mock_pikathread): @@ -652,7 +652,7 @@ def test_subscribe_to_broadcast(mock_pikathread): assert kwargs == { "callback": mock.ANY, "exchange": str(mock.sentinel.queue1), - "consumer_tag": "1", + "subscription_id": 1, "reconnectable": False, } @@ -668,14 +668,14 @@ def test_subscribe_to_broadcast(mock_pikathread): assert kwargs == { "callback": mock.ANY, "exchange": str(mock.sentinel.queue2), - "consumer_tag": "2", + "subscription_id": 2, "reconnectable": False, } transport._unsubscribe(1) - mock_pikathread.unsubscribe.assert_called_once_with("1") + mock_pikathread.unsubscribe.assert_called_once_with(1) transport._unsubscribe(2) - mock_pikathread.unsubscribe.assert_called_with("2") + mock_pikathread.unsubscribe.assert_called_with(2) @mock.patch("workflows.transport.pika_transport.pika") @@ -882,7 +882,7 @@ def _callback(channel, basic_deliver, properties, body): # Make a subscription and wait for it to be valid thread.subscribe_broadcast( - exchange, _callback, consumer_tag=0, reconnectable=True + exchange, _callback, subscription_id=1, reconnectable=True ).result() test_channel.basic_publish(exchange, routing_key="", body="A Message") @@ -907,7 +907,9 @@ def _got_message(*args): got_message.set() exchange = test_channel.temporary_exchange_declare(exchange_type="fanout") - thread.subscribe_broadcast(exchange, _got_message, reconnectable=True).result() + thread.subscribe_broadcast( + exchange, _got_message, reconnectable=True, subscription_id=1 + ).result() # Force reconnection - normally we want this to be transparent, but # let's twiddle the internals so we can wait for reconnection as we @@ -929,7 +931,10 @@ def _got_message_2(*args): got_message_2.set() thread.subscribe_broadcast( - exchange, _got_message_2, reconnectable=False + exchange, + _got_message_2, + reconnectable=False, + subscription_id=2, ).result() # Make sure that the thread ends instead of reconnect if we force a disconnection @@ -953,7 +958,9 @@ def _get_message(*args): print(f"Got message: {pprint.pformat(args)}") messages.put(args[3]) - thread.subscribe_queue(queue, _get_message, reconnectable=True) + thread.subscribe_queue( + queue, _get_message, reconnectable=True, subscription_id=1 + ) test_channel.basic_publish("", queue, "This is a message") assert messages.get(timeout=2) == b"This is a message" @@ -1042,13 +1049,13 @@ def _get_message(*args): messages.put(args[3]) thread.subscribe_queue( - queue, _get_message, reconnectable=True, consumer_tag="1" + queue, _get_message, reconnectable=True, subscription_id=1 ) test_channel.basic_publish("", queue, "This is a message") assert messages.get(timeout=1) == b"This is a message" # Issue an unsubscribe then wait for confirmation - thread.unsubscribe("1").result() + thread.unsubscribe(1).result() # Send a message again test_channel.basic_publish("", queue, "This is a message again")