Skip to content

Commit be6d994

Browse files
authored
Route outgoing messages to correct channel (#99)
If it is part of a transaction then use the relevant transacted channel, otherwise use the (untransacted) shared channel. Resolves #98
1 parent 8d883e7 commit be6d994

2 files changed

Lines changed: 41 additions & 13 deletions

File tree

src/workflows/transport/pika_transport.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,14 @@ def _unsubscribe(self, sub_id: int, **kwargs):
438438
# Callback reference is kept as further messages may already have been received
439439

440440
def _send(
441-
self, destination, message, headers=None, delay=None, expiration=None, **kwargs
441+
self,
442+
destination,
443+
message,
444+
headers=None,
445+
delay=None,
446+
expiration=None,
447+
transaction: Optional[int] = None,
448+
**kwargs,
442449
):
443450
"""
444451
Send a message to a queue.
@@ -449,6 +456,7 @@ def _send(
449456
headers: Further arbitrary headers to pass to pika
450457
delay: Delay transport of message by this many seconds
451458
expiration: Optional TTL expiration time, relative to sending time
459+
transaction: Transaction ID if message should be part of a transaction
452460
"""
453461
if not headers:
454462
headers = {}
@@ -470,6 +478,7 @@ def _send(
470478
body=message,
471479
properties=properties,
472480
mandatory=True,
481+
transaction_id=transaction,
473482
).result()
474483

475484
def _broadcast(
@@ -479,6 +488,7 @@ def _broadcast(
479488
headers=None,
480489
delay=None,
481490
expiration: Optional[int] = None,
491+
transaction: Optional[int] = None,
482492
**kwargs,
483493
):
484494
"""Send a message to a fanout exchange.
@@ -489,6 +499,7 @@ def _broadcast(
489499
headers: Further arbitrary headers to pass to pika
490500
delay: Delay transport of message by this many seconds
491501
expiration: Optional TTL expiration time, in seconds, relative to sending time
502+
transaction: Transaction ID if message should be part of a transaction
492503
kwargs: Arbitrary arguments for other transports. Ignored.
493504
"""
494505
assert delay is None, "Delay Not implemented"
@@ -511,6 +522,7 @@ def _broadcast(
511522
body=message,
512523
properties=properties,
513524
mandatory=False,
525+
transaction_id=transaction,
514526
).result()
515527

516528
def _transaction_begin(
@@ -930,6 +942,7 @@ def send(
930942
body: Union[str, bytes],
931943
properties: pika.spec.BasicProperties = None,
932944
mandatory: bool = True,
945+
transaction_id: Optional[int] = None,
933946
) -> Future[None]:
934947
"""Send a message. Thread-safe."""
935948

@@ -941,7 +954,11 @@ def send(
941954
def _send():
942955
if future.set_running_or_notify_cancel():
943956
try:
944-
self._get_shared_channel().basic_publish(
957+
if transaction_id:
958+
channel = self._transactions_by_id[transaction_id]
959+
else:
960+
channel = self._get_shared_channel()
961+
channel.basic_publish(
945962
exchange=exchange,
946963
routing_key=routing_key,
947964
body=body,

tests/transport/test_pika.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,16 @@ def test_check_config_file_behaviour(mockpika, mock_pikathread, tmp_path):
9393
cfgfile = tmp_path / "config"
9494
cfgfile.write_text(
9595
"""
96-
# An example pika configuration file
97-
# Only lines in the [pika] block will be interpreted
98-
99-
[rabbit]
100-
host = localhost
101-
port = 5672
102-
username = someuser
103-
password = somesecret
104-
vhost = namespace
105-
"""
96+
# An example pika configuration file
97+
# Only lines in the [pika] block will be interpreted
98+
99+
[rabbit]
100+
host = localhost
101+
port = 5672
102+
username = someuser
103+
password = somesecret
104+
vhost = namespace
105+
"""
106106
)
107107

108108
parser.parse_args(
@@ -218,6 +218,7 @@ def test_broadcast_status(mockpika, mock_pikathread):
218218
"body": mock.ANY,
219219
"properties": mock.ANY,
220220
"mandatory": False,
221+
"transaction_id": None,
221222
}
222223
statusdict = json.loads(kwargs.get("body"))
223224
assert statusdict["status"] == str(mock.sentinel.status)
@@ -242,6 +243,7 @@ def test_send_message(mockpika, mock_pikathread):
242243
"body": mock.sentinel.message,
243244
"mandatory": True,
244245
"properties": mock.ANY,
246+
"transaction_id": None,
245247
}
246248
assert mockproperties.call_args[1].get("headers") == {}
247249
assert int(mockproperties.call_args[1].get("delivery_mode")) == 2
@@ -291,6 +293,7 @@ def test_sending_message_with_expiration(mockpika, mock_pikathread):
291293
"body": mock.sentinel.message,
292294
"mandatory": True,
293295
"properties": mock.ANY,
296+
"transaction_id": None,
294297
}
295298
assert int(mockproperties.return_value.expiration) == 120 * 1000
296299

@@ -341,6 +344,7 @@ def test_send_broadcast(mockpika, mock_pikathread):
341344
"body": mock.sentinel.message,
342345
"properties": mock.ANY,
343346
"mandatory": False,
347+
"transaction_id": None,
344348
}
345349

346350
transport._broadcast(
@@ -360,6 +364,7 @@ def test_send_broadcast(mockpika, mock_pikathread):
360364
"body": mock.sentinel.message,
361365
"properties": mock.ANY,
362366
"mandatory": False,
367+
"transaction_id": None,
363368
}
364369

365370
# Delay not implemented yet
@@ -404,6 +409,7 @@ def test_broadcasting_message_with_expiration(mockpika, mock_pikathread):
404409
"body": mock.sentinel.message,
405410
"properties": mock.ANY,
406411
"mandatory": False,
412+
"transaction_id": None,
407413
}
408414

409415

@@ -447,6 +453,7 @@ def test_messages_are_serialized_for_transport(mock_pikathread):
447453
"body": banana_str,
448454
"properties": pika.BasicProperties(delivery_mode=2, headers={}),
449455
"mandatory": True,
456+
"transaction_id": None,
450457
}
451458

452459
transport.broadcast(str(mock.sentinel.queue2), banana)
@@ -458,6 +465,7 @@ def test_messages_are_serialized_for_transport(mock_pikathread):
458465
"body": banana_str,
459466
"properties": pika.BasicProperties(delivery_mode=2, headers={}),
460467
"mandatory": False,
468+
"transaction_id": None,
461469
}
462470

463471
with pytest.raises(TypeError):
@@ -481,6 +489,7 @@ def test_messages_are_not_serialized_for_raw_transport(_mockpika, mock_pikathrea
481489
"body": banana,
482490
"mandatory": True,
483491
"properties": mock.ANY,
492+
"transaction_id": None,
484493
}
485494

486495
mock_pikathread.send.reset_mock()
@@ -494,6 +503,7 @@ def test_messages_are_not_serialized_for_raw_transport(_mockpika, mock_pikathrea
494503
"body": banana,
495504
"properties": mock.ANY,
496505
"mandatory": False,
506+
"transaction_id": None,
497507
}
498508

499509
mock_pikathread.send.reset_mock()
@@ -507,6 +517,7 @@ def test_messages_are_not_serialized_for_raw_transport(_mockpika, mock_pikathrea
507517
"body": mock.sentinel.unserializable,
508518
"mandatory": True,
509519
"properties": mock.ANY,
520+
"transaction_id": None,
510521
}
511522

512523

@@ -818,7 +829,6 @@ def temporary_queue_declare(
818829
):
819830
"""
820831
Declare an auto-named queue that is automatically deleted on test end.
821-
822832
"""
823833
queue = self.channel.queue_declare(
824834
"", auto_delete=auto_delete, exclusive=exclusive, **kwargs
@@ -994,6 +1004,7 @@ def test_pikathread_send(connection_params, test_channel):
9941004
).result()
9951005

9961006
# But should fail with it declared
1007+
pytest.xfail("UnroutableError is not raised without publisher confirms, #96")
9971008
with pytest.raises(pika.exceptions.UnroutableError):
9981009
thread.send(
9991010
"", "unroutable-missing-queue", "Another Message", mandatory=True

0 commit comments

Comments
 (0)