diff --git a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py index 1b4f61c0..012f88fd 100644 --- a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py +++ b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py @@ -208,15 +208,28 @@ def _task(connections_list: list): if conn_reference is None: continue + # For pool-proxied connections (e.g. SQLAlchemy's + # _ConnectionFairy), prefer .invalidate() over .close(). + # close() checks the connection back into the pool, which first + # runs rollback-on-return against the failed writer -- an + # unbounded blocking call on this thread when the host is + # unreachable -- and re-pools the connection if that rollback + # happens to succeed (e.g. the old writer came back as a + # reader). invalidate() skips the rollback and discards the + # connection so the pool opens a fresh one on next checkout. try: - conn_reference.close() + inv = getattr(conn_reference, "invalidate", None) + if callable(inv): + inv() + else: + conn_reference.close() except Exception: # Swallow this exception, current connection should be useless anyway pass def _invalidate_connections(self, connections_list: list): invalidate_connection_thread: Thread = Thread(daemon=True, target=self._task, - args=[connections_list]) # type: ignore + args=[connections_list]) # type: ignore[arg-type] invalidate_connection_thread.start() def log_opened_connections(self): diff --git a/aws_advanced_python_wrapper/failover_plugin.py b/aws_advanced_python_wrapper/failover_plugin.py index 062000c2..99157f65 100644 --- a/aws_advanced_python_wrapper/failover_plugin.py +++ b/aws_advanced_python_wrapper/failover_plugin.py @@ -375,13 +375,13 @@ def _invalidate_current_connection(self): if self._plugin_service.is_in_transaction: self._plugin_service.update_in_transaction(True) try: - driver_dialect.execute(DbApiMethod.CONNECTION_ROLLBACK.method_name, lambda: conn.rollback()) + driver_dialect.execute(DbApiMethod.CONNECTION_ROLLBACK.method_name, lambda: conn.rollback(), conn=conn) conn.rollback() except Exception: pass try: - return driver_dialect.execute(DbApiMethod.CONNECTION_CLOSE.method_name, lambda: conn.close()) + return driver_dialect.execute(DbApiMethod.CONNECTION_CLOSE.method_name, lambda: conn.close(), conn=conn) except Exception: pass diff --git a/aws_advanced_python_wrapper/failover_v2_plugin.py b/aws_advanced_python_wrapper/failover_v2_plugin.py index 9d443b89..26070636 100644 --- a/aws_advanced_python_wrapper/failover_v2_plugin.py +++ b/aws_advanced_python_wrapper/failover_v2_plugin.py @@ -281,7 +281,7 @@ def _get_reader_failover_connection(self) -> ReaderFailoverResult: remaining_readers.remove(reader_candidate) self._plugin_service.driver_dialect.execute( - DbApiMethod.CONNECTION_CLOSE.method_name, lambda: candidate_conn.close()) + DbApiMethod.CONNECTION_CLOSE.method_name, lambda: candidate_conn.close(), conn=candidate_conn) if role == HostRole.WRITER: reader_candidates.remove(reader_candidate) @@ -301,7 +301,7 @@ def _get_reader_failover_connection(self) -> ReaderFailoverResult: return ReaderFailoverResult(candidate_conn, updated_host_info) self._plugin_service.driver_dialect.execute( - DbApiMethod.CONNECTION_CLOSE.method_name, lambda: candidate_conn.close()) + DbApiMethod.CONNECTION_CLOSE.method_name, lambda: candidate_conn.close(), conn=candidate_conn) if role == HostRole.WRITER: is_original_writer_still_writer = True except Exception: @@ -381,7 +381,7 @@ def _invalidate_current_connection(self) -> None: try: self._plugin_service.driver_dialect.execute( - DbApiMethod.CONNECTION_CLOSE.method_name, lambda: conn.close()) + DbApiMethod.CONNECTION_CLOSE.method_name, lambda: conn.close(), conn=conn) except Exception: pass diff --git a/tests/unit/test_aurora_connection_tracker.py b/tests/unit/test_aurora_connection_tracker.py index 50697599..dac80ae3 100644 --- a/tests/unit/test_aurora_connection_tracker.py +++ b/tests/unit/test_aurora_connection_tracker.py @@ -18,6 +18,7 @@ import pytest from _weakrefset import WeakSet +from sqlalchemy.pool import QueuePool from aws_advanced_python_wrapper.errors import FailoverError @@ -176,3 +177,49 @@ def test_invalidate_all_connections_drains_set(mocker): tracker.invalidate_all_connections(host_info=host_info) assert len(captured) == 1 + + +def test_task_invalidates_pool_proxied_connections(mocker): + pool_proxied_conn = mocker.MagicMock() + + OpenedConnectionTracker._task([pool_proxied_conn]) + + pool_proxied_conn.invalidate.assert_called_once_with() + pool_proxied_conn.close.assert_not_called() + + +def test_task_closes_plain_connections_without_invalidate(mocker): + plain_conn = mocker.MagicMock() + del plain_conn.invalidate + + OpenedConnectionTracker._task([plain_conn]) + + plain_conn.close.assert_called_once_with() + + +def test_task_discards_pooled_connection_from_queue_pool(mocker): + # End-to-end against a real SQLAlchemy QueuePool: invalidating the + # checked-out fairy must discard the underlying driver connection without + # running rollback-on-return against the failed host, so the pool opens a + # fresh connection on the next checkout instead of re-pooling the old one. + raw_connections = [] + + def creator(): + raw_conn = mocker.MagicMock() + raw_connections.append(raw_conn) + return raw_conn + + queue_pool = QueuePool(creator, pool_size=1, max_overflow=0) + fairy = queue_pool.connect() + + OpenedConnectionTracker._task([fairy]) + + assert len(raw_connections) == 1 + raw_connections[0].rollback.assert_not_called() + raw_connections[0].close.assert_called_once() + + replacement = queue_pool.connect() + assert len(raw_connections) == 2 + assert replacement.driver_connection is raw_connections[1] + replacement.close() + queue_pool.dispose()