From 0298a009c3d555eb7e36bb40ebbd2fb67a7fb07e Mon Sep 17 00:00:00 2001 From: Bohdan Siryk Date: Wed, 30 Oct 2024 18:01:09 +0200 Subject: [PATCH] 1. Updated the way how the driver constructs stmt cache keys. The current code base uses initial keyspace provided by the user to construct the keys. Since proto v5 we also should account for keyspace bounding for a specific query, so the driver should use the bounded keyspace instead of the initial to construct the key. 2. Changed the way how routing key cache keys are constructed to account the keyspace overriding as well. --- cassandra_test.go | 140 ++++++++++++++++++++++++++++++++++++++++++++-- conn.go | 33 +++++++---- session.go | 30 +++++++--- 3 files changed, 180 insertions(+), 23 deletions(-) diff --git a/cassandra_test.go b/cassandra_test.go index 773bb288c..d0e919746 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -1483,7 +1483,7 @@ func TestQueryInfo(t *testing.T) { defer session.Close() conn := getRandomConn(t, session) - info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil) + info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil, conn.currentKeyspace) if err != nil { t.Fatalf("Failed to execute query for preparing statement: %v", err) @@ -2602,7 +2602,7 @@ func TestRoutingKey(t *testing.T) { t.Fatalf("failed to create table with error '%v'", err) } - routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -2626,7 +2626,7 @@ func TestRoutingKey(t *testing.T) { } // verify the cache is working - routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -2660,7 +2660,7 @@ func TestRoutingKey(t *testing.T) { t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey) } - routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -3606,3 +3606,135 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { require.Equal(t, preparedStatementAfterTableAltering2.resultMetadataID, preparedStatementAfterTableAltering3.resultMetadataID) require.Equal(t, preparedStatementAfterTableAltering2.response, preparedStatementAfterTableAltering3.response) } + +func TestStmtCacheUsesOverriddenKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }` + + err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_stmt_cache")) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_stmt_cache.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + const insertQuery = "INSERT INTO stmt_cache_uses_overridden_ks (id) VALUES (?)" + + // Inserting data via Batch to ensure that batches + // properly accounts for keyspace overriding + b1 := session.NewBatch(LoggedBatch) + b1.Query(insertQuery, 1) + err = session.ExecuteBatch(b1) + require.NoError(t, err) + + b2 := session.NewBatch(LoggedBatch) + b2.SetKeyspace("gocql_test_stmt_cache") + b2.Query(insertQuery, 2) + err = session.ExecuteBatch(b2) + require.NoError(t, err) + + var scannedID int + + const selectStmt = "SELECT * FROM stmt_cache_uses_overridden_ks" + + // By default in our test suite session uses gocql_test ks + err = session.Query(selectStmt).Scan(&scannedID) + require.NoError(t, err) + require.Equal(t, 1, scannedID) + + scannedID = 0 + err = session.Query(selectStmt).SetKeyspace("gocql_test_stmt_cache").Scan(&scannedID) + require.NoError(t, err) + require.Equal(t, 2, scannedID) + + session.Query("DROP KEYSPACE IF EXISTS gocql_test_stmt_cache").Exec() +} + +func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }` + + err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_routing_key_cache")) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_routing_key_cache.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + getRoutingKeyInfo := func(key string) *routingKeyInfo { + t.Helper() + session.routingKeyInfoCache.mu.Lock() + value, _ := session.routingKeyInfoCache.lru.Get(key) + session.routingKeyInfoCache.mu.Unlock() + + inflight := value.(*inflightCachedEntry) + return inflight.value.(*routingKeyInfo) + } + + const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks (id) VALUES (?)" + + // Running batch in default ks gocql_test + b1 := session.NewBatch(LoggedBatch) + b1.Query(insertQuery, 1) + _, err = b1.GetRoutingKey() + require.NoError(t, err) + + // Ensuring that the cache contains the query with default ks + routingKeyInfo1 := getRoutingKeyInfo("gocql_test" + b1.Entries[0].Stmt) + require.Equal(t, "gocql_test", routingKeyInfo1.keyspace) + + // Running batch in gocql_test_routing_key_cache ks + b2 := session.NewBatch(LoggedBatch) + b2.SetKeyspace("gocql_test_routing_key_cache") + b2.Query(insertQuery, 2) + _, err = b2.GetRoutingKey() + require.NoError(t, err) + + // Ensuring that the cache contains the query with gocql_test_routing_key_cache ks + routingKeyInfo2 := getRoutingKeyInfo("gocql_test_routing_key_cache" + b2.Entries[0].Stmt) + require.Equal(t, "gocql_test_routing_key_cache", routingKeyInfo2.keyspace) + + const selectStmt = "SELECT * FROM routing_key_cache_uses_overridden_ks WHERE id=?" + + // Running query in default ks gocql_test + q1 := session.Query(selectStmt, 1) + _, err = q1.GetRoutingKey() + require.NoError(t, err) + require.Equal(t, "gocql_test", q1.routingInfo.keyspace) + + // Running query in gocql_test_routing_key_cache ks + q2 := session.Query(selectStmt, 1) + _, err = q2.SetKeyspace("gocql_test_routing_key_cache").GetRoutingKey() + require.NoError(t, err) + require.Equal(t, "gocql_test_routing_key_cache", q2.routingInfo.keyspace) + + session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec() +} diff --git a/conn.go b/conn.go index 7c2728ba0..1fd3ea3cf 100644 --- a/conn.go +++ b/conn.go @@ -1410,8 +1410,8 @@ type inflightPrepare struct { preparedStatment *preparedStatment } -func (c *Conn) prepareStatementForKeyspace(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) { - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) +func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) { + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), keyspace, stmt) flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare { flight := &inflightPrepare{ done: make(chan struct{}), @@ -1486,10 +1486,6 @@ func (c *Conn) prepareStatementForKeyspace(ctx context.Context, stmt string, tra } } -func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) { - return c.prepareStatementForKeyspace(ctx, stmt, tracer, c.currentKeyspace) -} - func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error { if named, ok := value.(*namedValue); ok { dst.name = named.name @@ -1531,6 +1527,13 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params.nowInSeconds = qry.nowInSecondsValue } + // If a keyspace for the qry is overriden, + // then we should use it to create stmt cache key + usedKeyspace := c.currentKeyspace + if qry.keyspace != "" { + usedKeyspace = qry.keyspace + } + var ( frame frameBuilder info *preparedStatment @@ -1539,7 +1542,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if !qry.skipPrepare && qry.shouldPrepare() { // Prepare all DML queries. Other queries can not be prepared. var err error - info, err = c.prepareStatementForKeyspace(ctx, qry.stmt, qry.trace, qry.keyspace) + info, err = c.prepareStatement(ctx, qry.stmt, qry.trace, usedKeyspace) if err != nil { return &Iter{err: err} } @@ -1584,6 +1587,9 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // Set "keyspace" and "table" property in the query if it is present in preparedMetadata qry.routingInfo.mu.Lock() qry.routingInfo.keyspace = info.request.keyspace + if info.request.keyspace == "" { + qry.routingInfo.keyspace = usedKeyspace + } qry.routingInfo.table = info.request.table qry.routingInfo.mu.Unlock() } else { @@ -1616,7 +1622,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // If a RESULT/Rows message reports // changed resultset metadata with the Metadata_changed flag, the reported new // resultset metadata must be used in subsequent executions - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) oldInflight, ok := c.session.stmtsLRU.get(stmtCacheKey) if ok { newInflight := &inflightPrepare{ @@ -1685,7 +1691,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // is not consistent with regards to its schema. return iter case *RequestErrUnprepared: - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) return c.executeQuery(ctx, qry) case error: @@ -1767,6 +1773,11 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { req.nowInSeconds = batch.nowInSeconds } + usedKeyspace := c.currentKeyspace + if batch.keyspace != "" { + usedKeyspace = batch.keyspace + } + stmts := make(map[string]string, len(batch.Entries)) for i := 0; i < n; i++ { @@ -1774,7 +1785,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { b := &req.statements[i] if len(entry.Args) > 0 || entry.binding != nil { - info, err := c.prepareStatementForKeyspace(batch.Context(), entry.Stmt, batch.trace, batch.keyspace) + info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace, usedKeyspace) if err != nil { return &Iter{err: err} } @@ -1836,7 +1847,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { case *RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] if found { - key := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) + key := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, stmt) c.session.stmtsLRU.evictPreparedID(key, x.StatementId) } return c.executeBatch(ctx, batch) diff --git a/session.go b/session.go index 2175c28e2..774b2cc09 100644 --- a/session.go +++ b/session.go @@ -591,11 +591,20 @@ func (s *Session) getConn() *Conn { return nil } -// returns routing key indexes and type info -func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyInfo, error) { +// Returns routing key indexes and type info. +// If keyspace == "" it uses the keyspace which is specified in Cluster.Keyspace +func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace string) (*routingKeyInfo, error) { + if keyspace == "" { + keyspace = s.cfg.Keyspace + } + + routingKeyInfoCacheKey := keyspace + stmt + s.routingKeyInfoCache.mu.Lock() - entry, cached := s.routingKeyInfoCache.lru.Get(stmt) + // Using here keyspace + stmt as a cache key because + // the query keyspace could be overridden via SetKeyspace + entry, cached := s.routingKeyInfoCache.lru.Get(routingKeyInfoCacheKey) if cached { // done accessing the cache s.routingKeyInfoCache.mu.Unlock() @@ -619,7 +628,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI inflight := new(inflightCachedEntry) inflight.wg.Add(1) defer inflight.wg.Done() - s.routingKeyInfoCache.lru.Add(stmt, inflight) + s.routingKeyInfoCache.lru.Add(routingKeyInfoCacheKey, inflight) s.routingKeyInfoCache.mu.Unlock() var ( @@ -635,7 +644,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI } // get the query info for the statement - info, inflight.err = conn.prepareStatement(ctx, stmt, nil) + info, inflight.err = conn.prepareStatement(ctx, stmt, nil, keyspace) if inflight.err != nil { // don't cache this error s.routingKeyInfoCache.Remove(stmt) @@ -651,7 +660,9 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI } table := info.request.table - keyspace := info.request.keyspace + if info.request.keyspace != "" { + keyspace = info.request.keyspace + } if len(info.request.pkeyColumns) > 0 { // proto v4 dont need to calculate primary key columns @@ -1146,6 +1157,9 @@ func (q *Query) Keyspace() string { if q.routingInfo.keyspace != "" { return q.routingInfo.keyspace } + if q.keyspace != "" { + return q.keyspace + } if q.session == nil { return "" @@ -1177,7 +1191,7 @@ func (q *Query) GetRoutingKey() ([]byte, error) { } // try to determine the routing key - routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt) + routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt, q.keyspace) if err != nil { return nil, err } @@ -2009,7 +2023,7 @@ func (b *Batch) GetRoutingKey() ([]byte, error) { return nil, nil } // try to determine the routing key - routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt) + routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.keyspace) if err != nil { return nil, err }