Skip to content

Commit c75ba89

Browse files
authored
Merge 55a8dc9 into a92d7b3
2 parents a92d7b3 + 55a8dc9 commit c75ba89

6 files changed

Lines changed: 148 additions & 39 deletions

File tree

src/workflows/services/sample_transaction.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ class SampleTxn(CommonService):
1717
def initializing(self):
1818
"""Subscribe to a channel. Received messages must be acknowledged."""
1919
self.subid = self._transport.subscribe(
20-
"transient.transaction", self.receive_message, acknowledgement=True
20+
"transient.transaction",
21+
self.receive_message,
22+
acknowledgement=True,
23+
prefetch_count=1000,
2124
)
2225

2326
@staticmethod

src/workflows/transport/common_transport.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -288,15 +288,24 @@ def nack(self, message, subscription_id: Optional[int] = None, **kwargs):
288288
)
289289
self._nack(message_id, subscription_id=subscription_id, **kwargs)
290290

291-
def transaction_begin(self, **kwargs) -> int:
291+
def transaction_begin(self, subscription_id: Optional[int] = None, **kwargs) -> int:
292292
"""Start a new transaction.
293-
:param **kwargs: Further parameters for the transport layer. For example
293+
:param **kwargs: Further parameters for the transport layer.
294294
:return: A transaction ID that can be passed to other functions.
295295
"""
296296
self.__transaction_id += 1
297297
self.__transactions.add(self.__transaction_id)
298-
self.log.debug("Starting transaction with ID %d", self.__subscription_id)
299-
self._transaction_begin(self.__transaction_id, **kwargs)
298+
if subscription_id:
299+
self.log.debug(
300+
"Starting transaction with ID %d on subscription %d",
301+
self.__transaction_id,
302+
subscription_id,
303+
)
304+
else:
305+
self.log.debug("Starting transaction with ID %d", self.__transaction_id)
306+
self._transaction_begin(
307+
self.__transaction_id, subscription_id=subscription_id, **kwargs
308+
)
300309
return self.__transaction_id
301310

302311
def transaction_abort(self, transaction_id: int, **kwargs):
@@ -405,21 +414,23 @@ def _nack(self, message_id, subscription_id, **kwargs):
405414
"""
406415
raise NotImplementedError("Transport interface not implemented")
407416

408-
def _transaction_begin(self, transaction_id, **kwargs):
417+
def _transaction_begin(
418+
self, transaction_id: int, *, subscription_id: Optional[int] = None, **kwargs
419+
) -> None:
409420
"""Start a new transaction.
410421
:param transaction_id: ID for this transaction in the transport layer.
411422
:param **kwargs: Further parameters for the transport layer.
412423
"""
413424
raise NotImplementedError("Transport interface not implemented")
414425

415-
def _transaction_abort(self, transaction_id, **kwargs):
426+
def _transaction_abort(self, transaction_id: int, **kwargs) -> None:
416427
"""Abort a transaction and roll back all operations.
417428
:param transaction_id: ID of transaction to be aborted.
418429
:param **kwargs: Further parameters for the transport layer.
419430
"""
420431
raise NotImplementedError("Transport interface not implemented")
421432

422-
def _transaction_commit(self, transaction_id, **kwargs):
433+
def _transaction_commit(self, transaction_id: int, **kwargs) -> None:
423434
"""Commit a transaction.
424435
:param transaction_id: ID of transaction to be committed.
425436
:param **kwargs: Further parameters for the transport layer.

src/workflows/transport/pika_transport.py

Lines changed: 124 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -513,26 +513,26 @@ def _broadcast(
513513
mandatory=False,
514514
).result()
515515

516-
def _transaction_begin(self, **kwargs):
517-
"""Enter transaction mode.
518-
:param **kwargs: Further parameters for the transport layer.
516+
def _transaction_begin(
517+
self, transaction_id: int, *, subscription_id: Optional[int] = None, **kwargs
518+
) -> None:
519+
"""Start a new transaction.
520+
:param transaction_id: ID for this transaction in the transport layer.
521+
:param subscription_id: Tie the transaction to a specific channel containing this subscription.
519522
"""
520-
raise NotImplementedError()
521-
# self._channel.tx_select()
523+
self._pika_thread.tx_select(transaction_id, subscription_id)
522524

523-
def _transaction_abort(self, **kwargs):
525+
def _transaction_abort(self, transaction_id: int, **kwargs) -> None:
524526
"""Abort a transaction and roll back all operations.
525-
:param **kwargs: Further parameters for the transport layer.
527+
:param transaction_id: ID of transaction to be aborted.
526528
"""
527-
raise NotImplementedError()
528-
# self._channel.tx_rollback()
529+
self._pika_thread.tx_rollback(transaction_id)
529530

530-
def _transaction_commit(self, **kwargs):
531+
def _transaction_commit(self, transaction_id: int, **kwargs) -> None:
531532
"""Commit a transaction.
532-
:param **kwargs: Further parameters for the transport layer.
533+
:param transaction_id: ID of transaction to be committed.
533534
"""
534-
raise NotImplementedError()
535-
# self._channel.tx_commit()
535+
self._pika_thread.tx_commit(transaction_id)
536536

537537
def _ack(
538538
self, message_id, subscription_id: int, *, multiple: bool = False, **_kwargs
@@ -694,9 +694,12 @@ def __init__(
694694
self._subscriptions: Dict[int, _PikaSubscription] = {}
695695
# The pika connection object
696696
self._connection: Optional[pika.BlockingConnection] = None
697-
# Per-subscription channels. May be pointing to the shared channel
697+
# Index of per-subscription channels.
698698
self._pika_channels: Dict[int, BlockingChannel] = {}
699-
# A common, shared channel, used for non-QoS subscriptions
699+
# Bidirectional index of all ongoing transactions. May include the shared channel
700+
self._transactions_by_id: Dict[int, BlockingChannel] = {}
701+
self._transactions_by_channel: Dict[BlockingChannel, int] = {}
702+
# A common, shared channel, used for sending messages outside of transactions.
700703
self._pika_shared_channel: Optional[BlockingChannel]
701704
# Are we allowed to reconnect. Can only be turned off, never on
702705
self._reconnection_allowed: bool = True
@@ -907,6 +910,11 @@ def _unsubscribe():
907910
logger.debug("Closing channel that is now unused")
908911
channel.close()
909912

913+
# Forget about any ongoing transactions on the channel
914+
if channel in self._transactions_by_channel:
915+
transaction_id = self._transactions_by_channel.pop(channel)
916+
self._transactions_by_id.pop(transaction_id)
917+
910918
result.set_result(None)
911919
except BaseException as e:
912920
result.set_exception(e)
@@ -974,6 +982,100 @@ def nack(
974982
lambda: channel.basic_nack(delivery_tag, multiple=multiple, requeue=requeue)
975983
)
976984

985+
def tx_select(
986+
self, transaction_id: int, subscription_id: Optional[int]
987+
) -> Future[None]:
988+
"""Set a channel to transaction mode. Thread-safe.
989+
:param transaction_id: ID for this transaction in the transport layer.
990+
:param subscription_id: Tie the transaction to a specific channel containing this subscription.
991+
"""
992+
993+
if not self._connection:
994+
raise RuntimeError("Cannot transact on unstarted connection")
995+
996+
future: Future[None] = Future()
997+
998+
def _tx_select():
999+
if future.set_running_or_notify_cancel():
1000+
try:
1001+
if subscription_id:
1002+
if subscription_id not in self._pika_channels:
1003+
raise KeyError(
1004+
f"Could not find subscription {subscription_id} to begin transaction"
1005+
)
1006+
channel = self._pika_channels[subscription_id]
1007+
else:
1008+
channel = self._get_shared_channel()
1009+
if channel in self._transactions_by_channel:
1010+
raise KeyError(
1011+
f"Channel {channel} is already running transaction {self._transactions_by_channel[channel]}, so can't start transaction {transaction_id}"
1012+
)
1013+
channel.tx_select()
1014+
self._transactions_by_channel[channel] = transaction_id
1015+
self._transactions_by_id[transaction_id] = channel
1016+
1017+
future.set_result(None)
1018+
except BaseException as e:
1019+
future.set_exception(e)
1020+
raise
1021+
1022+
self._connection.add_callback_threadsafe(_tx_select)
1023+
return future
1024+
1025+
def tx_rollback(self, transaction_id: int) -> Future[None]:
1026+
"""Abort a transaction and roll back all operations. Thread-safe.
1027+
:param transaction_id: ID of transaction to be aborted.
1028+
"""
1029+
if not self._connection:
1030+
raise RuntimeError("Cannot transact on unstarted connection")
1031+
1032+
future: Future[None] = Future()
1033+
1034+
def _tx_rollback():
1035+
if future.set_running_or_notify_cancel():
1036+
try:
1037+
channel = self._transactions_by_id.pop(transaction_id, None)
1038+
if not channel:
1039+
raise KeyError(
1040+
f"Could not find transaction {transaction_id} to roll back"
1041+
)
1042+
self._transactions_by_channel.pop(channel)
1043+
channel.tx_rollback()
1044+
future.set_result(None)
1045+
except BaseException as e:
1046+
future.set_exception(e)
1047+
raise
1048+
1049+
self._connection.add_callback_threadsafe(_tx_rollback)
1050+
return future
1051+
1052+
def tx_commit(self, transaction_id: int) -> Future[None]:
1053+
"""Commit a transaction.
1054+
:param transaction_id: ID of transaction to be committed. Thread-safe..
1055+
"""
1056+
if not self._connection:
1057+
raise RuntimeError("Cannot transact on unstarted connection")
1058+
1059+
future: Future[None] = Future()
1060+
1061+
def _tx_commit():
1062+
if future.set_running_or_notify_cancel():
1063+
try:
1064+
channel = self._transactions_by_id.pop(transaction_id, None)
1065+
if not channel:
1066+
raise KeyError(
1067+
f"Could not find transaction {transaction_id} to commit"
1068+
)
1069+
self._transactions_by_channel.pop(channel)
1070+
channel.tx_commit()
1071+
future.set_result(None)
1072+
except BaseException as e:
1073+
future.set_exception(e)
1074+
raise
1075+
1076+
self._connection.add_callback_threadsafe(_tx_commit)
1077+
return future
1078+
9771079
@property
9781080
def connection_alive(self) -> bool:
9791081
"""
@@ -989,7 +1091,7 @@ def connection_alive(self) -> bool:
9891091
)
9901092

9911093
# NOTE: With reconnection lifecycle this probably doesn't make sense
992-
# on it's own. It might make sense to add this returning a
1094+
# on its own. It might make sense to add this returning a
9931095
# connection-specific 'token' - presumably the user might want
9941096
# to ensure that a connection is still the same connection
9951097
# and thus adhering to various within-connection guarantees.
@@ -1017,7 +1119,7 @@ def _get_shared_channel(self) -> BlockingChannel:
10171119

10181120
if not self._pika_shared_channel:
10191121
self._pika_shared_channel = self._connection.channel()
1020-
self._pika_shared_channel.confirm_delivery()
1122+
##### self._pika_shared_channel.confirm_delivery()
10211123
return self._pika_shared_channel
10221124

10231125
def _recreate_subscriptions(self):
@@ -1050,22 +1152,18 @@ def _add_subscription(self, subscription_id: int, subscription: _PikaSubscriptio
10501152
f"Subscription {subscription_id} to '{subscription.destination}' is not reconnectable. Turning reconnection off."
10511153
)
10521154

1053-
# Either open a channel (if prefetch) or use the shared one
1054-
if subscription.prefetch_count == 0:
1055-
channel = self._get_shared_channel()
1056-
else:
1057-
channel = self._connection.channel()
1058-
channel.confirm_delivery()
1059-
channel.basic_qos(prefetch_count=subscription.prefetch_count)
1155+
# Open a dedicated channel for this subscription
1156+
channel = self._connection.channel()
1157+
channel.basic_qos(prefetch_count=subscription.prefetch_count)
10601158

1061-
if subscription.kind == _PikaSubscriptionKind.FANOUT:
1159+
if subscription.kind is _PikaSubscriptionKind.FANOUT:
10621160
# If a FANOUT subscription, then we need to create and bind
10631161
# a temporary queue to receive messages from the exchange
10641162
queue = channel.queue_declare("", exclusive=True).method.queue
10651163
assert queue is not None
10661164
channel.queue_bind(queue, subscription.destination)
10671165
subscription.queue = queue
1068-
elif subscription.kind == _PikaSubscriptionKind.DIRECT:
1166+
elif subscription.kind is _PikaSubscriptionKind.DIRECT:
10691167
subscription.queue = subscription.destination
10701168
else:
10711169
raise NotImplementedError(f"Unknown subscription kind: {subscription.kind}")

src/workflows/transport/stomp_transport.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,21 +410,18 @@ def _broadcast(
410410
def _transaction_begin(self, transaction_id, **kwargs):
411411
"""Start a new transaction.
412412
:param transaction_id: ID for this transaction in the transport layer.
413-
:param **kwargs: Further parameters for the transport layer.
414413
"""
415414
self._conn.begin(transaction=transaction_id)
416415

417416
def _transaction_abort(self, transaction_id, **kwargs):
418417
"""Abort a transaction and roll back all operations.
419418
:param transaction_id: ID of transaction to be aborted.
420-
:param **kwargs: Further parameters for the transport layer.
421419
"""
422420
self._conn.abort(transaction_id)
423421

424422
def _transaction_commit(self, transaction_id, **kwargs):
425423
"""Commit a transaction.
426424
:param transaction_id: ID of transaction to be committed.
427-
:param **kwargs: Further parameters for the transport layer.
428425
"""
429426
self._conn.commit(transaction_id)
430427

tests/services/test_sample_transaction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_txnservice_subscribes_to_channel():
5252
p.initializing()
5353

5454
mock_transport.subscribe.assert_called_once_with(
55-
mock.ANY, p.receive_message, acknowledgement=True
55+
mock.ANY, p.receive_message, acknowledgement=True, prefetch_count=1000
5656
)
5757

5858

tests/transport/test_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def test_create_and_destroy_transactions():
248248
t = ct.transaction_begin()
249249

250250
assert t
251-
ct._transaction_begin.assert_called_once_with(t)
251+
ct._transaction_begin.assert_called_once_with(t, subscription_id=None)
252252

253253
ct.transaction_abort(t)
254254
with pytest.raises(workflows.Error):

0 commit comments

Comments
 (0)