From 28be3343154e1844dc006135ec3e7ece740f64a2 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 18:06:34 -0400 Subject: [PATCH] cluster: fix host state pool refresh races --- cassandra/cluster.py | 24 ++++++++++----- tests/unit/test_cluster.py | 44 +++++++++++++++++++++++++++ tests/unit/test_control_connection.py | 29 ++++++++++++++++-- 3 files changed, 87 insertions(+), 10 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..5c72e26125 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1975,6 +1975,10 @@ def on_up(self, host): with host.lock: host.set_up() host._currently_handling_node_up = False + for listener in self.listeners: + listener.on_up(host) + for session in tuple(self.sessions): + session.update_created_pools() # for testing purposes return futures @@ -2020,7 +2024,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): Intended for internal use only. """ if self.is_shutdown: - return + return False with host.lock: was_up = host.is_up @@ -2035,14 +2039,15 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): if pool_state: connected |= pool_state['open_count'] > 0 if connected: - return + return False host.set_down() if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): - return + return False log.warning("Host %s has been marked down", host) self.on_down_potentially_blocking(host, is_host_addition) + return True def on_add(self, host, refresh_nodes=True): if self.is_shutdown: @@ -2134,8 +2139,8 @@ def on_remove(self, host): def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): is_down = host.signal_connection_failure(connection_exc) if is_down: - self.on_down(host, is_host_addition, expect_host_to_be_down) - return is_down + return self.on_down(host, is_host_addition, expect_host_to_be_down) + return False def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None): """ @@ -3315,7 +3320,9 @@ def update_created_pools(self): # we don't eagerly set is_up on previously ignored hosts. None is included here # to allow us to attempt connections to hosts that have gone from ignored to something # else. - if distance != HostDistance.IGNORED and host.is_up in (True, None): + if (distance != HostDistance.IGNORED and + host.is_up in (True, None) and + not getattr(host, '_currently_handling_node_up', False)): future = self.add_or_renew_pool(host, False) elif distance != pool.host_distance: # the distance has changed @@ -4226,9 +4233,10 @@ def _signal_error(self): # host may be None if it's already been removed, but that indicates # that errors have already been reported, so we're fine if host: - self._cluster.signal_connection_failure( + is_down = self._cluster.signal_connection_failure( host, self._connection.last_error, is_host_addition=False) - return + if is_down: + return # if the connection is not defunct or the host already left, reconnect # manually diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index a4f0ebc4d3..8e800fc2a4 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -15,6 +15,7 @@ import logging import socket +from concurrent.futures import Future from unittest.mock import patch, Mock import uuid @@ -229,6 +230,27 @@ def test_connection_factory_passes_compression_kwarg(self): assert factory.call_args.kwargs['compression'] == expected assert cluster.compression == expected + def test_on_up_without_pool_futures_notifies_listeners(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_down() + cluster.metadata.add_or_return_host(host) + + session = Mock() + session.add_or_renew_pool.return_value = None + cluster.sessions.add(session) + + listener = Mock() + cluster.register_listener(listener) + + cluster.on_up(host) + + assert host.is_up is True + listener.on_up.assert_called_once_with(host) + session.update_created_pools.assert_called_once_with() + class SchedulerTest(unittest.TestCase): # TODO: this suite could be expanded; for now just adding a test covering a ticket @@ -339,6 +361,28 @@ def test_set_keyspace_escapes_quotes(self, *_): assert query == 'USE simple_ks', ( "Simple keyspace names should not be quoted, got: %r" % query) + def test_update_created_pools_skips_host_with_node_up_in_progress(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + cluster.metadata.add_or_return_host(host) + cluster.profile_manager.populate(cluster, [host]) + cluster.profile_manager.on_up(host) + + completed = Future() + completed.set_result(True) + + with patch.object(Session, "add_or_renew_pool", return_value=completed) as add_or_renew_pool: + session = Session(cluster, [host]) + add_or_renew_pool.reset_mock() + + session._pools = {} + host._currently_handling_node_up = True + + assert session.update_created_pools() == set() + add_or_renew_pool.assert_not_called() + class ProtocolVersionTests(unittest.TestCase): def test_protocol_downgrade_test(self): diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..1cc7d7a19c 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -13,15 +13,16 @@ # limitations under the License. import unittest +import uuid from concurrent.futures import ThreadPoolExecutor from unittest.mock import Mock, ANY, call from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS -from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile +from cassandra.cluster import Cluster, ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile from cassandra.pool import Host -from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory +from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory, ConnectionException from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, ConstantReconnectionPolicy, IdentityTranslator) @@ -301,6 +302,30 @@ def test_wait_for_schema_agreement_none_timeout(self): cc._time = self.time assert cc.wait_for_schema_agreement() + def test_signal_error_reconnects_when_host_down_signal_is_discounted(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + cluster.metadata.add_or_return_host(host) + + session = Mock() + session.get_pool_state.return_value = {host: {"open_count": 1}} + cluster.sessions.add(session) + + connection_error = ConnectionException("control connection failed", endpoint=host.endpoint) + cluster.control_connection._connection = Mock( + endpoint=host.endpoint, + is_defunct=True, + last_error=connection_error) + cluster.control_connection.reconnect = Mock() + + cluster.control_connection._signal_error() + + assert host.is_up is True + cluster.control_connection.reconnect.assert_called_once_with() + def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() meta = self.cluster.metadata