diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 06e6293ef..b032149a7 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2388,7 +2388,7 @@ def _prepare_all_queries(self, host): else: for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): if keyspace is not None: - connection.set_keyspace_blocking(keyspace) + connection.set_keyspace_blocking(keyspace, self.control_connection_timeout) # prepare 10 statements at a time ks_statements = list(ks_statements) diff --git a/cassandra/connection.py b/cassandra/connection.py index ebdfe9999..722566c26 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1498,14 +1498,14 @@ def _handle_auth_response(self, auth_response): log.error(msg, self.endpoint, auth_response) raise ProtocolError(msg % (self.endpoint, auth_response)) - def set_keyspace_blocking(self, keyspace): + def set_keyspace_blocking(self, keyspace, timeout=None): if not keyspace or keyspace == self.keyspace: return query = QueryMessage(query='USE "%s"' % (keyspace,), consistency_level=ConsistencyLevel.ONE) try: - result = self.wait_for_response(query) + result = self.wait_for_response(query, timeout=timeout) except InvalidRequestException as ire: # the keyspace probably doesn't exist raise ire.to_exception() diff --git a/cassandra/pool.py b/cassandra/pool.py index 738fc8e6d..593909751 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session): self._keyspace = session.keyspace if self._keyspace: - first_connection.set_keyspace_blocking(self._keyspace) + first_connection.set_keyspace_blocking(self._keyspace, session.cluster.control_connection_timeout) if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable: self.host.sharding_info = first_connection.features.sharding_info self._open_connections_for_all_shards(first_connection.features.shard_id) @@ -615,7 +615,7 @@ def _replace(self, connection): connection = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) if self._keyspace: - connection.set_keyspace_blocking(self._keyspace) + connection.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout) self._connections[connection.features.shard_id] = connection except Exception: log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,)) @@ -766,7 +766,7 @@ def _open_connection_to_missing_shard(self, shard_id): self.host ) if self._keyspace: - conn.set_keyspace_blocking(self._keyspace) + conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout) self._connections[conn.features.shard_id] = conn if old_conn is not None: @@ -953,7 +953,7 @@ def __init__(self, host, host_distance, session): self._keyspace = session.keyspace if self._keyspace: for conn in self._connections: - conn.set_keyspace_blocking(self._keyspace) + conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout) self._trash = set() self._next_trash_allowed_at = time.time() @@ -1053,7 +1053,7 @@ def _add_conn_if_under_max(self): try: conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) if self._keyspace: - conn.set_keyspace_blocking(self._session.keyspace) + conn.set_keyspace_blocking(self._session.keyspace, self._session.cluster.control_connection_timeout) self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL with self._lock: new_connections = self._connections[:] + [conn] diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 43356dbd8..d43e8808e 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -32,6 +32,7 @@ RetryPolicy, SimpleConvictionPolicy, HostDistance, AddressTranslator, TokenAwarePolicy, HostFilterPolicy) from cassandra import ConsistencyLevel +from cassandra.protocol import ProtocolHandler, QueryMessage from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory from cassandra.auth import PlainTextAuthProvider, SaslAuthProvider @@ -484,6 +485,42 @@ def test_refresh_schema_table(self): self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query()) cluster.shutdown() + def test_use_keyspace_blocking(self): + ks = "test_refresh_schema_type" + + cluster = TestCluster() + send_msg_orig = cluster.connection_class.send_msg + + def send_msg_patched(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, + decoder=ProtocolHandler.decode_message, result_metadata=None): + if isinstance(msg, QueryMessage) and f'USE "{ks}"' in msg.query: + orig_decoder = decoder + def decode_patched(protocol_version, protocol_features, user_type_map, stream_id, flags, opcode, body, + decompressor, result_metadata): + time.sleep(cluster.control_connection_timeout + 0.1) + return orig_decoder(protocol_version, protocol_features, user_type_map, stream_id, flags, + opcode, body, decompressor, result_metadata) + + decoder = decode_patched + + return send_msg_orig(self, msg, request_id, cb, encoder, decoder, result_metadata) + + cluster.connection_class.send_msg = send_msg_patched + + cluster.connect().execute(""" + CREATE KEYSPACE IF NOT EXISTS %s + WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } + """ % ks) + + try: + cluster.connect(ks) + except NoHostAvailable: + pass + except Exception as e: + self.fail(f"got unexpected exception {e}") + else: + self.fail("connection should fail, but was not") + def test_refresh_schema_type(self): if get_server_versions()[0] < (2, 1, 0): raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1')