Skip to content

Commit

Permalink
Make blocking set keyspace query to fail by timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
dkropachev committed Aug 9, 2024
1 parent 7e0b02d commit eaacfb9
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 8 deletions.
2 changes: 1 addition & 1 deletion cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2388,7 +2388,7 @@ def _prepare_all_queries(self, host):
else:
for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace):
if keyspace is not None:
connection.set_keyspace_blocking(keyspace)
connection.set_keyspace_blocking(keyspace, self.control_connection_timeout)

# prepare 10 statements at a time
ks_statements = list(ks_statements)
Expand Down
4 changes: 2 additions & 2 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,14 +1498,14 @@ def _handle_auth_response(self, auth_response):
log.error(msg, self.endpoint, auth_response)
raise ProtocolError(msg % (self.endpoint, auth_response))

def set_keyspace_blocking(self, keyspace):
def set_keyspace_blocking(self, keyspace, timeout=None):
if not keyspace or keyspace == self.keyspace:
return

query = QueryMessage(query='USE "%s"' % (keyspace,),
consistency_level=ConsistencyLevel.ONE)
try:
result = self.wait_for_response(query)
result = self.wait_for_response(query, timeout=timeout)
except InvalidRequestException as ire:
# the keyspace probably doesn't exist
raise ire.to_exception()
Expand Down
10 changes: 5 additions & 5 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session):
self._keyspace = session.keyspace

if self._keyspace:
first_connection.set_keyspace_blocking(self._keyspace)
first_connection.set_keyspace_blocking(self._keyspace, session.cluster.control_connection_timeout)
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
self.host.sharding_info = first_connection.features.sharding_info
self._open_connections_for_all_shards(first_connection.features.shard_id)
Expand Down Expand Up @@ -615,7 +615,7 @@ def _replace(self, connection):
connection = self._session.cluster.connection_factory(self.host.endpoint,
on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
connection.set_keyspace_blocking(self._keyspace)
connection.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)
self._connections[connection.features.shard_id] = connection
except Exception:
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
Expand Down Expand Up @@ -766,7 +766,7 @@ def _open_connection_to_missing_shard(self, shard_id):
self.host
)
if self._keyspace:
conn.set_keyspace_blocking(self._keyspace)
conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)

self._connections[conn.features.shard_id] = conn
if old_conn is not None:
Expand Down Expand Up @@ -953,7 +953,7 @@ def __init__(self, host, host_distance, session):
self._keyspace = session.keyspace
if self._keyspace:
for conn in self._connections:
conn.set_keyspace_blocking(self._keyspace)
conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)

self._trash = set()
self._next_trash_allowed_at = time.time()
Expand Down Expand Up @@ -1053,7 +1053,7 @@ def _add_conn_if_under_max(self):
try:
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
conn.set_keyspace_blocking(self._session.keyspace)
conn.set_keyspace_blocking(self._session.keyspace, self._session.cluster.control_connection_timeout)
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
with self._lock:
new_connections = self._connections[:] + [conn]
Expand Down
37 changes: 37 additions & 0 deletions tests/integration/standard/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
RetryPolicy, SimpleConvictionPolicy, HostDistance,
AddressTranslator, TokenAwarePolicy, HostFilterPolicy)
from cassandra import ConsistencyLevel
from cassandra.protocol import ProtocolHandler, QueryMessage

from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory
from cassandra.auth import PlainTextAuthProvider, SaslAuthProvider
Expand Down Expand Up @@ -484,6 +485,42 @@ def test_refresh_schema_table(self):
self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query())
cluster.shutdown()

def test_use_keyspace_blocking(self):
ks = "test_refresh_schema_type"

cluster = TestCluster()
send_msg_orig = cluster.connection_class.send_msg

def send_msg_patched(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
decoder=ProtocolHandler.decode_message, result_metadata=None):
if isinstance(msg, QueryMessage) and f'USE "{ks}"' in msg.query:
orig_decoder = decoder
def decode_patched(protocol_version, protocol_features, user_type_map, stream_id, flags, opcode, body,
decompressor, result_metadata):
time.sleep(cluster.control_connection_timeout + 0.1)
return orig_decoder(protocol_version, protocol_features, user_type_map, stream_id, flags,
opcode, body, decompressor, result_metadata)

decoder = decode_patched

return send_msg_orig(self, msg, request_id, cb, encoder, decoder, result_metadata)

cluster.connection_class.send_msg = send_msg_patched

cluster.connect().execute("""
CREATE KEYSPACE IF NOT EXISTS %s
WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
""" % ks)

try:
cluster.connect(ks)
except NoHostAvailable:
pass
except Exception as e:
self.fail(f"got unexpected exception {e}")
else:
self.fail("connection should fail, but was not")

def test_refresh_schema_type(self):
if get_server_versions()[0] < (2, 1, 0):
raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1')
Expand Down

0 comments on commit eaacfb9

Please sign in to comment.