Skip to content

Commit

Permalink
Invalidate tablets when table or keyspace is deleted
Browse files Browse the repository at this point in the history
Delete tablets for table or keyspace when one is deleted.
When host is removed from cluster delete all tablets that have this host in it.
Ensure that if it happens when control connection is reconnection.
  • Loading branch information
dkropachev committed Jan 4, 2025
1 parent 347f332 commit dcecfbd
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 6 deletions.
1 change: 1 addition & 0 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3641,6 +3641,7 @@ def _set_new_connection(self, conn):
with self._lock:
old = self._connection
self._connection = conn
self.refresh_schema()

if old:
log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn)
Expand Down
13 changes: 12 additions & 1 deletion cassandra/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,14 @@ def _rebuild_all(self, parser):
current_keyspaces = set()
for keyspace_meta in parser.get_all_keyspaces():
current_keyspaces.add(keyspace_meta.name)
old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None)
old_keyspace_meta: KeyspaceMetadata = self.keyspaces.get(keyspace_meta.name, None)
self.keyspaces[keyspace_meta.name] = keyspace_meta
if old_keyspace_meta:
self._keyspace_updated(keyspace_meta.name)
for table_name, old_table_meta in old_keyspace_meta.tables.items():
new_table_meta: TableMetadata = keyspace_meta.tables.get(table_name, None)
if new_table_meta is None:
self._table_removed(keyspace_meta.name, table_name)
else:
self._keyspace_added(keyspace_meta.name)

Expand Down Expand Up @@ -265,17 +269,22 @@ def _drop_aggregate(self, keyspace, aggregate):
except KeyError:
pass

def _table_removed(self, keyspace, table):
self._tablets.drop_tablet(keyspace, table)

def _keyspace_added(self, ksname):
if self.token_map:
self.token_map.rebuild_keyspace(ksname, build_if_absent=False)

def _keyspace_updated(self, ksname):
if self.token_map:
self.token_map.rebuild_keyspace(ksname, build_if_absent=False)
self._tablets.drop_tablet(ksname)

def _keyspace_removed(self, ksname):
if self.token_map:
self.token_map.remove_keyspace(ksname)
self._tablets.drop_tablet(ksname)

def rebuild_token_map(self, partitioner, token_map):
"""
Expand Down Expand Up @@ -340,11 +349,13 @@ def add_or_return_host(self, host):
return host, True

def remove_host(self, host):
self._tablets.drop_tablet_by_host_id(host.host_id)
with self._hosts_lock:
self._host_id_by_endpoint.pop(host.endpoint, False)
return bool(self._hosts.pop(host.host_id, False))

def remove_host_by_host_id(self, host_id, endpoint=None):
self._tablets.drop_tablet_by_host_id(host_id)
with self._hosts_lock:
if endpoint and self._host_id_by_endpoint[endpoint] == host_id:
self._host_id_by_endpoint.pop(endpoint, False)
Expand Down
39 changes: 39 additions & 0 deletions cassandra/tablets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from audioop import reverse
from threading import Lock
from uuid import UUID


class Tablet(object):
Expand Down Expand Up @@ -32,6 +34,12 @@ def from_row(first_token, last_token, replicas):
return tablet
return None

def replica_contains_host_id(self, uuid: UUID) -> bool:
for replica in self.replicas:
if replica.uuid == uuid:
return True
return False


class Tablets(object):
_lock = None
Expand All @@ -51,6 +59,37 @@ def get_tablet_for_key(self, keyspace, table, t):
return tablet[id]
return None

def drop_tablet(self, keyspace: str, table: str = None):
with self._lock:
if table is not None:
self._tablets.pop((keyspace, table), None)
return

to_be_deleted = []
for key in self._tablets.keys():
if key[0] == keyspace:
to_be_deleted.append(key)

for key in to_be_deleted:
del self._tablets[key]

def drop_tablet_by_host_id(self, host_id: UUID):
if host_id is None:
return
with self._lock:
for key, tablets in self._tablets.keys():
to_be_deleted = []
for tablet_id, tablet in enumerate(tablets):
if tablet.replica_contains_host_id(host_id):
to_be_deleted.append(tablet_id)

if len(to_be_deleted) == 0:
continue

for tablet_id in reverse(to_be_deleted):
tablets.pop(tablet_id)
self._tablets[key] = tablets

def add_tablet(self, keyspace, table, tablet):
with self._lock:
tablets_for_table = self._tablets.setdefault((keyspace, table), [])
Expand Down
53 changes: 48 additions & 5 deletions tests/integration/experiments/test_tablets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def setup_class(cls):
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
reconnection_policy=ConstantReconnectionPolicy(1))
cls.session = cls.cluster.connect()
cls.create_ks_and_cf(cls)
cls.create_ks_and_cf(cls.session)
cls.create_data(cls.session)

@classmethod
Expand Down Expand Up @@ -47,6 +47,10 @@ def verify_same_host_in_tracing(self, results):
self.assertEqual(len(host_set), 1)
self.assertIn('locally', "\n".join([event.activity for event in events]))

def get_tablet_record(self, query):
metadata = self.session.cluster.metadata
return metadata._tablets.get_tablet_for_key(query.keyspace, query.table, metadata.token_map.token_class.from_key(query.routing_key))

def verify_same_shard_in_tracing(self, results):
traces = results.get_query_trace()
events = traces.events
Expand All @@ -69,13 +73,14 @@ def verify_same_shard_in_tracing(self, results):
self.assertEqual(len(shard_set), 1)
self.assertIn('locally', "\n".join([event.activity for event in events]))

def create_ks_and_cf(self):
self.session.execute(
@classmethod
def create_ks_and_cf(cls, session):
session.execute(
"""
DROP KEYSPACE IF EXISTS test1
"""
)
self.session.execute(
session.execute(
"""
CREATE KEYSPACE test1
WITH replication = {
Expand All @@ -86,7 +91,7 @@ def create_ks_and_cf(self):
}
""")

self.session.execute(
session.execute(
"""
CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
""")
Expand All @@ -109,6 +114,8 @@ def query_data_shard_select(self, session, verify_in_tracing=True):
""")

bound = prepared.bind([(2)])
assert self.get_tablet_record(bound) is not None

results = session.execute(bound, trace=True)
self.assertEqual(results, [(2, 2, 0)])
if verify_in_tracing:
Expand All @@ -121,6 +128,8 @@ def query_data_host_select(self, session, verify_in_tracing=True):
""")

bound = prepared.bind([(2)])
assert self.get_tablet_record(bound) is not None

results = session.execute(bound, trace=True)
self.assertEqual(results, [(2, 2, 0)])
if verify_in_tracing:
Expand All @@ -133,6 +142,8 @@ def query_data_shard_insert(self, session, verify_in_tracing=True):
""")

bound = prepared.bind([(51), (1), (2)])
assert self.get_tablet_record(bound) is not None

results = session.execute(bound, trace=True)
if verify_in_tracing:
self.verify_same_shard_in_tracing(results)
Expand All @@ -144,6 +155,8 @@ def query_data_host_insert(self, session, verify_in_tracing=True):
""")

bound = prepared.bind([(52), (1), (2)])
assert self.get_tablet_record(bound) is not None

results = session.execute(bound, trace=True)
if verify_in_tracing:
self.verify_same_host_in_tracing(results)
Expand All @@ -155,3 +168,33 @@ def test_tablets(self):
def test_tablets_shard_awareness(self):
self.query_data_shard_select(self.session)
self.query_data_shard_insert(self.session)

def test_tablets_invalidation_on_ks_deleted(self):
self.run_tablets_invalidation_on_ks_deleted(False)
self.run_tablets_invalidation_on_ks_deleted(True)

def run_tablets_invalidation_on_ks_deleted(self, while_reconnecting: bool):
# Make sure driver holds tablet info
bound = self.session.prepare(
"""
SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
""").bind([(2)])
self.session.execute(bound)
assert self.get_tablet_record(bound) is not None

if while_reconnecting:
conn = self.session.cluster.control_connection._connection
self.session.cluster.control_connection._connection = None
conn.close()

# Drop and recreate ks and table to trigger tablets invalidation
self.create_ks_and_cf(self.cluster.connect())

if while_reconnecting:
self.session.cluster.control_connection._reconnect()
else:
# Wait till driver pick up event and process it
time.sleep(3)

# Check if tablets information was purged
assert self.get_tablet_record(bound) is None

0 comments on commit dcecfbd

Please sign in to comment.