From fbdaa2b35e7960b3009a79d0e2c5d81aff8aa61d Mon Sep 17 00:00:00 2001 From: Bohdan Siryk Date: Thu, 18 Jul 2024 12:33:47 +0300 Subject: [PATCH] Support for Native Protocol 5 release version --- batch_test.go | 82 ++++++++++ cassandra_test.go | 148 +++++++++++++++++ common_test.go | 3 + compressor.go | 6 + conn.go | 192 +++++++++++++++++++--- conn_test.go | 95 +++++++++++ control.go | 2 +- crc.go | 58 +++++++ crc_test.go | 89 +++++++++++ frame.go | 395 ++++++++++++++++++++++++++++++++++++++++------ frame_test.go | 316 +++++++++++++++++++++++++++++++++++++ go.mod | 5 +- go.sum | 23 ++- lz4/lz4.go | 10 ++ prepared_cache.go | 17 ++ session.go | 44 ++++++ 16 files changed, 1417 insertions(+), 68 deletions(-) create mode 100644 crc.go create mode 100644 crc_test.go diff --git a/batch_test.go b/batch_test.go index 25f8c8364..ae4bd8853 100644 --- a/batch_test.go +++ b/batch_test.go @@ -28,6 +28,7 @@ package gocql import ( + "github.com/stretchr/testify/require" "testing" "time" ) @@ -84,3 +85,84 @@ func TestBatch_WithTimestamp(t *testing.T) { t.Errorf("got ts %d, expected %d", storedTs, micros) } } + +func TestBatch_WithNowInSeconds(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("Batch now in seconds are only available on protocol >= 5") + } + + if err := createTable(session, `CREATE TABLE batch_now_in_seconds (id int primary key, val text)`); err != nil { + t.Fatal(err) + } + + b := session.NewBatch(LoggedBatch) + b.WithNowInSeconds(0) + b.Query("INSERT INTO batch_now_in_seconds (id, val) VALUES (?, ?) USING TTL 20", 1, "val") + if err := session.ExecuteBatch(b); err != nil { + t.Fatal(err) + } + + var remainingTTL int + err := session.Query(`SELECT TTL(val) FROM batch_now_in_seconds WHERE id = ?`, 1). + WithNowInSeconds(10). + Scan(&remainingTTL) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, remainingTTL, 10) +} + +func TestBatch_SetKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("keyspace for BATCH message is not supported in protocol < 5") + } + + const keyspaceStmt = ` + CREATE KEYSPACE IF NOT EXISTS gocql_keyspace_override_test + WITH replication = { + 'class': 'SimpleStrategy', + 'replication_factor': '1' + }; +` + + err := session.Query(keyspaceStmt).Exec() + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_keyspace_override_test.batch_keyspace(id int, value text, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + ids := []int{1, 2} + texts := []string{"val1", "val2"} + + b := session.NewBatch(LoggedBatch).SetKeyspace("gocql_keyspace_override_test") + b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[0], texts[0]) + b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[1], texts[1]) + err = session.ExecuteBatch(b) + if err != nil { + t.Fatal(err) + } + + var ( + id int + text string + ) + + iter := session.Query("SELECT * FROM gocql_keyspace_override_test.batch_keyspace").Iter() + defer iter.Close() + + for i := 0; iter.Scan(&id, &text); i++ { + require.Equal(t, id, ids[i]) + require.Equal(t, text, texts[i]) + } +} diff --git a/cassandra_test.go b/cassandra_test.go index 797a7cf7f..9d95a81f5 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -32,6 +32,7 @@ import ( "context" "errors" "fmt" + "github.com/stretchr/testify/require" "io" "math" "math/big" @@ -3288,3 +3289,150 @@ func TestQuery_NamedValues(t *testing.T) { t.Fatal(err) } } + +func TestQuery_WithNowInSeconds(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("Query now in seconds are only available on protocol >= 5") + } + + if err := createTable(session, `CREATE TABLE query_now_in_seconds (id int primary key, val text)`); err != nil { + t.Fatal(err) + } + + err := session.Query("INSERT INTO query_now_in_seconds (id, val) VALUES (?, ?) USING TTL 20", 1, "val"). + WithNowInSeconds(int(0)). + Exec() + if err != nil { + t.Fatal(err) + } + + var remainingTTL int + err = session.Query(`SELECT TTL(val) FROM query_now_in_seconds WHERE id = ?`, 1). + WithNowInSeconds(10). + Scan(&remainingTTL) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, remainingTTL, 10) +} + +func TestQuery_SetKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("keyspace for QUERY message is not supported in protocol < 5") + } + + const keyspaceStmt = ` + CREATE KEYSPACE IF NOT EXISTS gocql_query_keyspace_override_test + WITH replication = { + 'class': 'SimpleStrategy', + 'replication_factor': '1' + }; +` + + err := session.Query(keyspaceStmt).Exec() + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_query_keyspace_override_test.query_keyspace(id int, value text, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + expectedID := 1 + expectedText := "text" + + // Testing PREPARE message + err = session.Query("INSERT INTO gocql_query_keyspace_override_test.query_keyspace (id, value) VALUES (?, ?)", expectedID, expectedText).Exec() + if err != nil { + t.Fatal(err) + } + + var ( + id int + text string + ) + + q := session.Query("SELECT * FROM gocql_query_keyspace_override_test.query_keyspace"). + SetKeyspace("gocql_query_keyspace_override_test") + err = q.Scan(&id, &text) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, expectedID, id) + require.Equal(t, expectedText, text) + + // Testing QUERY message + id = 0 + text = "" + + q = session.Query("SELECT * FROM gocql_query_keyspace_override_test.query_keyspace"). + SetKeyspace("gocql_query_keyspace_override_test") + q.skipPrepare = true + err = q.Scan(&id, &text) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, expectedID, id) + require.Equal(t, expectedText, text) +} + +func TestLargeSizeQuery(t *testing.T) { + session := createSession(t) + defer session.Close() + + if err := createTable(session, "CREATE TABLE gocql_test.large_size_query(id int, text_col text, PRIMARY KEY (id))"); err != nil { + t.Fatal(err) + } + + defer session.Close() + + longString := strings.Repeat("a", 500_000) + + err := session.Query("INSERT INTO gocql_test.large_size_query (id, text_col) VALUES (?, ?)", "1", longString).Exec() + if err != nil { + t.Fatal(err) + } + + var result string + err = session.Query("SELECT text_col FROM gocql_test.large_size_query").Scan(&result) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, longString, result) +} + +func TestQueryCompressionNotWorthIt(t *testing.T) { + session := createSession(t) + defer session.Close() + + if err := createTable(session, "CREATE TABLE gocql_test.compression_now_worth_it(id int, text_col text, PRIMARY KEY (id))"); err != nil { + t.Fatal(err) + } + + defer session.Close() + + str := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890!@#$%^&*()_+" + err := session.Query("INSERT INTO gocql_test.large_size_query (id, text_col) VALUES (?, ?)", "1", str).Exec() + if err != nil { + t.Fatal(err) + } + + var result string + err = session.Query("SELECT text_col FROM gocql_test.large_size_query").Scan(&result) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, str, result) +} diff --git a/common_test.go b/common_test.go index a5edb03c6..0f3aa00f9 100644 --- a/common_test.go +++ b/common_test.go @@ -27,6 +27,7 @@ package gocql import ( "flag" "fmt" + "github.com/gocql/gocql/lz4" "log" "net" "reflect" @@ -111,6 +112,8 @@ func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig { switch *flagCompressTest { case "snappy": cluster.Compressor = &SnappyCompressor{} + case "lz4": + cluster.Compressor = lz4.LZ4Compressor{} case "": default: panic("invalid compressor: " + *flagCompressTest) diff --git a/compressor.go b/compressor.go index f3d451a9f..f2e00ff96 100644 --- a/compressor.go +++ b/compressor.go @@ -32,6 +32,7 @@ type Compressor interface { Name() string Encode(data []byte) ([]byte, error) Decode(data []byte) ([]byte, error) + DecodeSized(data []byte, size uint32) ([]byte, error) } // SnappyCompressor implements the Compressor interface and can be used to @@ -50,3 +51,8 @@ func (s SnappyCompressor) Encode(data []byte) ([]byte, error) { func (s SnappyCompressor) Decode(data []byte) ([]byte, error) { return snappy.Decode(nil, data) } + +func (s SnappyCompressor) DecodeSized(data []byte, size uint32) ([]byte, error) { + buf := make([]byte, size) + return snappy.Decode(buf, data) +} diff --git a/conn.go b/conn.go index 3daca6250..29b06dcb2 100644 --- a/conn.go +++ b/conn.go @@ -26,6 +26,7 @@ package gocql import ( "bufio" + "bytes" "context" "crypto/tls" "errors" @@ -215,6 +216,14 @@ type Conn struct { host *HostInfo isSchemaV2 bool + // Only for proto v5+. + // Indicates if STARTUP has been completed. + // github.com/apache/cassandra/blob/trunk/doc/native_protocol_v5.spec + // 2.3.1 Initial Handshake + // In order to support both v5 and earlier formats, the v5 framing format is not + // applied to message exchanges before an initial handshake is completed. + startupCompleted bool + session *Session // true if connection close process for the connection started. @@ -474,8 +483,12 @@ func (s *startupCoordinator) startup(ctx context.Context, supported map[string][ case error: return v case *readyFrame: + // Connection is successfully set up and ready to use Native Protocol v5 + s.conn.startupCompleted = true return nil case *authenticateFrame: + // Connection is successfully set up and ready to use Native Protocol v5 + s.conn.startupCompleted = true return s.authenticateHandshake(ctx, v) default: return NewErrProtocol("Unknown type of response to startup frame: %s", v) @@ -593,8 +606,8 @@ func (c *Conn) serve(ctx context.Context) { c.closeWithError(err) } -func (c *Conn) discardFrame(head frameHeader) error { - _, err := io.CopyN(ioutil.Discard, c, int64(head.length)) +func (c *Conn) discardFrame(r io.Reader, head frameHeader) error { + _, err := io.CopyN(ioutil.Discard, r, int64(head.length)) if err != nil { return err } @@ -660,6 +673,16 @@ func (c *Conn) heartBeat(ctx context.Context) { } func (c *Conn) recv(ctx context.Context) error { + // If startup is completed and native proto 5+ is set up then we should + // unwrap payload body from v5 compressed/uncompressed frame + if c.startupCompleted && c.version > protoVersion4 { + return c.recvProtoV5Frame(ctx) + } + + return c.processFrame(ctx, c) +} + +func (c *Conn) processFrame(ctx context.Context, r io.Reader) error { // not safe for concurrent reads // read a full header, ignore timeouts, as this is being ran in a loop @@ -670,7 +693,7 @@ func (c *Conn) recv(ctx context.Context) error { headStartTime := time.Now() // were just reading headers over and over and copy bodies - head, err := readHeader(c.r, c.headerBuf[:]) + head, err := readHeader(r, c.headerBuf[:]) headEndTime := time.Now() if err != nil { return err @@ -694,7 +717,7 @@ func (c *Conn) recv(ctx context.Context) error { } else if head.stream == -1 { // TODO: handle cassandra event frames, we shouldnt get any currently framer := newFramer(c.compressor, c.version) - if err := framer.readFrame(c, &head); err != nil { + if err := framer.readFrame(r, &head); err != nil { return err } go c.session.handleEvent(framer) @@ -727,14 +750,14 @@ func (c *Conn) recv(ctx context.Context) error { c.mu.Unlock() if call == nil || !ok { c.logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head) - return c.discardFrame(head) + return c.discardFrame(r, head) } else if head.stream != call.streamID { panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream)) } framer := newFramer(c.compressor, c.version) - err = framer.readFrame(c, &head) + err = framer.readFrame(r, &head) if err != nil { // only net errors should cause the connection to be closed. Though // cassandra returning corrupt frames will be returned here as well. @@ -777,6 +800,47 @@ func (c *Conn) handleTimeout() { } } +func (c *Conn) recvProtoV5Frame(ctx context.Context) error { + var ( + payload []byte + isSelfContained bool + err error + ) + + // Read frame based on compression + if c.compressor != nil { + payload, isSelfContained, err = readCompressedFrame(c.r, c.compressor) + } else { + payload, isSelfContained, err = readUncompressedFrame(c.r) + } + if err != nil { + return err + } + + if isSelfContained { + return c.processAllEnvelopesInFrame(ctx, bytes.NewReader(payload)) + } + + head, err := readHeader(bytes.NewReader(payload), c.headerBuf[:]) + if err != nil { + return err + } + + const envelopeHeaderLength = 9 + buf := bytes.NewBuffer(make([]byte, 0, head.length+envelopeHeaderLength)) + buf.Write(payload) + + // Computing how many bytes of message left to read + bytesToRead := head.length - len(payload) + envelopeHeaderLength + + err = c.recvLastsFrames(buf, bytesToRead) + if err != nil { + return err + } + + return c.processFrame(ctx, buf) +} + type callReq struct { // resp will receive the frame that was sent as a response to this stream. resp chan callResp @@ -1086,7 +1150,29 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram return nil, err } - n, err := c.w.writeContext(ctx, framer.buf) + var n int + + if c.version > protoVersion4 && c.startupCompleted { + err = framer.prepareModernLayout() + if err != nil { + // closeWithError will block waiting for this stream to either receive a response + // or for us to timeout. + close(call.timeout) + // We failed to serialize the frame into a buffer. + // This should not affect the connection as we didn't write anything. We just free the current call. + c.mu.Lock() + if !c.closed { + delete(c.calls, call.streamID) + } + c.mu.Unlock() + // We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil + // check above could fail. + c.releaseStream(call) + return nil, err + } + } + + n, err = c.w.writeContext(ctx, framer.buf) if err != nil { // closeWithError will block waiting for this stream to either receive a response // or for us to timeout, close the timeout chan here. Im not entirely sure @@ -1223,9 +1309,10 @@ type StreamObserverContext interface { } type preparedStatment struct { - id []byte - request preparedMetadata - response resultMetadata + id []byte + resultMetadataID []byte + request preparedMetadata + response resultMetadata } type inflightPrepare struct { @@ -1235,7 +1322,7 @@ type inflightPrepare struct { preparedStatment *preparedStatment } -func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) { +func (c *Conn) prepareStatementForKeyspace(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) { stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare { flight := &inflightPrepare{ @@ -1253,7 +1340,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) statement: stmt, } if c.version > protoVersion4 { - prep.keyspace = c.currentKeyspace + prep.keyspace = keyspace } // we won the race to do the load, if our context is canceled we shouldnt @@ -1284,7 +1371,8 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) flight.preparedStatment = &preparedStatment{ // defensively copy as we will recycle the underlying buffer after we // return. - id: copyBytes(x.preparedID), + id: copyBytes(x.preparedID), + resultMetadataID: copyBytes(x.resultMetadataID), // the type info's should _not_ have a reference to the framers read buffer, // therefore we can just copy them directly. request: x.reqMeta, @@ -1310,6 +1398,10 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) } } +func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) { + return c.prepareStatementForKeyspace(ctx, stmt, tracer, c.currentKeyspace) +} + func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error { if named, ok := value.(*namedValue); ok { dst.name = named.name @@ -1347,7 +1439,9 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params.pageSize = qry.pageSize } if c.version > protoVersion4 { - params.keyspace = c.currentKeyspace + params.keyspace = qry.keyspace + params.useNowInSeconds = qry.useNowInSeconds + params.nowInSecondsValue = qry.nowInSecondsValue } var ( @@ -1358,7 +1452,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if !qry.skipPrepare && qry.shouldPrepare() { // Prepare all DML queries. Other queries can not be prepared. var err error - info, err = c.prepareStatement(ctx, qry.stmt, qry.trace) + info, err = c.prepareStatementForKeyspace(ctx, qry.stmt, qry.trace, qry.keyspace) if err != nil { return &Iter{err: err} } @@ -1394,9 +1488,10 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata) frame = &writeExecuteFrame{ - preparedID: info.id, - params: params, - customPayload: qry.customPayload, + preparedID: info.id, + params: params, + customPayload: qry.customPayload, + resultMetadataID: info.resultMetadataID, } // Set "keyspace" and "table" property in the query if it is present in preparedMetadata @@ -1430,6 +1525,24 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { case *resultVoidFrame: return &Iter{framer: framer} case *resultRowsFrame: + if x.meta.newMetadataID != nil { + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) + inflight, ok := c.session.stmtsLRU.get(stmtCacheKey) + if !ok { + // We didn't find the stmt in the cache, so we just re-prepare it + return c.executeQuery(ctx, qry) + } + + // Updating the result metadata id in prepared stmt + // + // If a RESULT/Rows message reports + // changed resultset metadata with the Metadata_changed flag, the reported new + // resultset metadata must be used in subsequent executions + inflight.preparedStatment.resultMetadataID = x.meta.newMetadataID + inflight.preparedStatment.response = x.meta + return c.executeQuery(ctx, qry) + } + iter := &Iter{ meta: x.meta, framer: framer, @@ -1554,6 +1667,12 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { customPayload: batch.CustomPayload, } + if c.version > protoVersion4 { + req.keyspace = batch.keyspace + req.useNowInSeconds = batch.useNowInSeconds + req.nowInSecondsValue = batch.nowInSecondsValue + } + stmts := make(map[string]string, len(batch.Entries)) for i := 0; i < n; i++ { @@ -1561,7 +1680,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { b := &req.statements[i] if len(entry.Args) > 0 || entry.binding != nil { - info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace) + info, err := c.prepareStatementForKeyspace(batch.Context(), entry.Stmt, batch.trace, batch.keyspace) if err != nil { return &Iter{err: err} } @@ -1756,6 +1875,41 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas) } +// recvLastsFrames reads proto v5 frames from Conn.r and writes decoded payload to dst. +// It reads data until the bytesToRead is reached. +// If Conn.compressor is not nil, it processes Compressed Format frames. +func (c *Conn) recvLastsFrames(dst *bytes.Buffer, bytesToRead int) error { + var read int + var segment []byte + var err error + for read != bytesToRead { + // Read frame based on compression + if c.compressor != nil { + segment, _, err = readCompressedFrame(c.r, c.compressor) + } else { + segment, _, err = readUncompressedFrame(c.r) + } + if err != nil { + return fmt.Errorf("gocql: failed to read non self-contained frame: %w", err) + } + + // Write the segment to the destination writer + n, _ := dst.Write(segment) + read += n + } + + return nil +} + +func (c *Conn) processAllEnvelopesInFrame(ctx context.Context, r *bytes.Reader) error { + var err error + for r.Len() > 0 && err == nil { + err = c.processFrame(ctx, r) + } + + return err +} + var ( ErrQueryArgLength = errors.New("gocql: query argument length mismatch") ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period") diff --git a/conn_test.go b/conn_test.go index cab4c2f8f..9df3ae8da 100644 --- a/conn_test.go +++ b/conn_test.go @@ -47,6 +47,7 @@ import ( "time" "github.com/gocql/gocql/internal/streams" + "github.com/stretchr/testify/require" ) const ( @@ -1300,3 +1301,97 @@ func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) { return framer, nil } + +func TestConnProcessAllEnvelopesInSingleFrame(t *testing.T) { + server, client, err := tcpConnPair() + require.NoError(t, err) + + c := &Conn{ + conn: server, + r: bufio.NewReader(server), + calls: make(map[int]*callReq), + version: protoVersion5, + addr: server.RemoteAddr().String(), + streams: streams.New(protoVersion5), + isSchemaV2: true, + startupCompleted: true, + w: &deadlineContextWriter{ + w: server, + timeout: time.Second * 10, + semaphore: make(chan struct{}, 1), + quit: make(chan struct{}), + }, + logger: Logger, + writeTimeout: time.Second * 10, + } + + call1 := &callReq{ + timeout: make(chan struct{}), + streamID: 1, + resp: make(chan callResp), + } + + call2 := &callReq{ + timeout: make(chan struct{}), + streamID: 2, + resp: make(chan callResp), + } + + c.calls[1] = call1 + c.calls[2] = call2 + + req := writeQueryFrame{ + statement: "SELECT * FROM system.local", + params: queryParams{ + consistency: Quorum, + keyspace: "gocql_test", + }, + } + + framer1 := newFramer(nil, protoVersion5) + err = req.buildFrame(framer1, 1) + require.NoError(t, err) + + framer2 := newFramer(nil, protoVersion5) + err = req.buildFrame(framer2, 2) + require.NoError(t, err) + + go func() { + var buf []byte + buf = append(buf, framer1.buf...) + buf = append(buf, framer2.buf...) + + uncompressedFrame, err := newUncompressedFrame(buf, true) + require.NoError(t, err) + + _, err = client.Write(uncompressedFrame) + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- c.recvProtoV5Frame(ctx) + }() + + go func() { + resp1 := <-call1.resp + close(call1.timeout) + // Skipping here the header of the envelope because resp.framer contains already parsed header + // and resp.framer.buf contains envelope body + require.Equal(t, framer1.buf[9:], resp1.framer.buf) + + resp2 := <-call2.resp + close(call2.timeout) + require.Equal(t, framer2.buf[9:], resp2.framer.buf) + }() + + select { + case <-ctx.Done(): + t.Fatal("Timed out waiting for frames") + case err := <-errCh: + require.NoError(t, err) + } +} diff --git a/control.go b/control.go index b30b44ea3..a2ce62a5f 100644 --- a/control.go +++ b/control.go @@ -216,7 +216,7 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) { hosts = shuffleHosts(hosts) connCfg := *c.session.connCfg - connCfg.ProtoVersion = 4 // TODO: define maxProtocol + connCfg.ProtoVersion = 5 // TODO: define maxProtocol handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) { // we should never get here, but if we do it means we connected to a diff --git a/crc.go b/crc.go new file mode 100644 index 000000000..874a68508 --- /dev/null +++ b/crc.go @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gocql + +import ( + "hash/crc32" +) + +var ( + // Initial CRC32 bytes: 0xFA, 0x2D, 0x55, 0xCA + initialCRC32Bytes = []byte{0xfa, 0x2d, 0x55, 0xca} +) + +// ChecksumIEEE calculates the CRC32 checksum of the given byte slice. +func ChecksumIEEE(b []byte) uint32 { + crc := crc32.NewIEEE() + crc.Write(initialCRC32Bytes) // Include initial CRC32 bytes + crc.Write(b) + return crc.Sum32() +} + +const ( + crc24Init = 0x875060 // Initial value for CRC24 calculation + crc24Poly = 0x1974F0B // Polynomial for CRC24 calculation +) + +// KoopmanChecksum calculates the CRC24 checksum using the Koopman polynomial. +func KoopmanChecksum(buf []byte) uint32 { + crc := crc24Init + for _, b := range buf { + crc ^= int(b) << 16 + + for i := 0; i < 8; i++ { + crc <<= 1 + if crc&0x1000000 != 0 { + crc ^= crc24Poly + } + } + } + + return uint32(crc) +} diff --git a/crc_test.go b/crc_test.go new file mode 100644 index 000000000..e2b44b3f8 --- /dev/null +++ b/crc_test.go @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gocql + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestChecksumIEEE(t *testing.T) { + tests := []struct { + name string + buf []byte + expected uint32 + }{ + // expected values are manually generated using crc24 impl in Cassandra + { + name: "empty buf", + buf: []byte{}, + expected: 1148681939, + }, + { + name: "buf filled with 0", + buf: []byte{0, 0, 0, 0, 0}, + expected: 1178391023, + }, + { + name: "buf filled with some data", + buf: []byte{1, 2, 3, 4, 5, 6}, + expected: 3536190002, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, ChecksumIEEE(tt.buf)) + }) + } +} + +func TestKoopmanChecksum(t *testing.T) { + tests := []struct { + name string + buf []byte + expected uint32 + }{ + // expected values are manually generated using crc32 impl in Cassandra + { + name: "buf filled with 0 (len 3)", + buf: []byte{0, 0, 0}, + expected: 8251255, + }, + { + name: "buf filled with 0 (len 5)", + buf: []byte{0, 0, 0, 0, 0}, + expected: 11185162, + }, + { + name: "buf filled with some data (len 3)", + buf: []byte{64, -30 & 0xff, 1}, + expected: 5891942, + }, + { + name: "buf filled with some data (len 5)", + buf: []byte{64, -30 & 0xff, 1, 0, 0}, + expected: 8775784, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, KoopmanChecksum(tt.buf)) + }) + } +} diff --git a/frame.go b/frame.go index d374ae574..def218de0 100644 --- a/frame.go +++ b/frame.go @@ -25,7 +25,9 @@ package gocql import ( + "bytes" "context" + "encoding/binary" "errors" "fmt" "io" @@ -70,6 +72,8 @@ const ( protoVersion5 = 0x05 maxFrameSize = 256 * 1024 * 1024 + + maxPayloadSize = 1<<17 - 1 ) type protoVersion byte @@ -168,16 +172,18 @@ const ( flagGlobalTableSpec int = 0x01 flagHasMorePages int = 0x02 flagNoMetaData int = 0x04 + flagMetaDataChanged int = 0x08 // query flags - flagValues byte = 0x01 - flagSkipMetaData byte = 0x02 - flagPageSize byte = 0x04 - flagWithPagingState byte = 0x08 - flagWithSerialConsistency byte = 0x10 - flagDefaultTimestamp byte = 0x20 - flagWithNameValues byte = 0x40 - flagWithKeyspace byte = 0x80 + flagValues uint32 = 0x01 + flagSkipMetaData uint32 = 0x02 + flagPageSize uint32 = 0x04 + flagWithPagingState uint32 = 0x08 + flagWithSerialConsistency uint32 = 0x10 + flagDefaultTimestamp uint32 = 0x20 + flagWithNameValues uint32 = 0x40 + flagWithKeyspace uint32 = 0x80 + flagWithNowInSeconds uint32 = 0x100 // prepare flags flagWithPreparedKeyspace uint32 = 0x01 @@ -524,7 +530,7 @@ func (f *framer) readFrame(r io.Reader, head *frameHeader) error { return fmt.Errorf("unable to read frame body: read %d/%d bytes: %v", n, head.length, err) } - if head.flags&flagCompress == flagCompress { + if f.proto < protoVersion5 && head.flags&flagCompress == flagCompress { if f.compres == nil { return NewErrProtocol("no compressor available with compressed frame body") } @@ -768,7 +774,7 @@ func (f *framer) finish() error { return ErrFrameTooBig } - if f.buf[1]&flagCompress == flagCompress { + if f.proto < protoVersion5 && f.buf[1]&flagCompress == flagCompress { if f.compres == nil { panic("compress flag set with no compressor") } @@ -1017,6 +1023,8 @@ type resultMetadata struct { // it is at minimum len(columns) but may be larger, for instance when a column // is a UDT or tuple. actualColCount int + + newMetadataID []byte } func (r *resultMetadata) morePages() bool { @@ -1060,6 +1068,10 @@ func (f *framer) parseResultMetadata() resultMetadata { meta.pagingState = copyBytes(f.readBytes()) } + if meta.flags&flagMetaDataChanged == flagMetaDataChanged { + meta.newMetadataID = copyBytes(f.readBytes()) + } + if meta.flags&flagNoMetaData == flagNoMetaData { return meta } @@ -1164,18 +1176,24 @@ func (f *framer) parseResultSetKeyspace() frame { type resultPreparedFrame struct { frameHeader - preparedID []byte - reqMeta preparedMetadata - respMeta resultMetadata + preparedID []byte + resultMetadataID []byte + reqMeta preparedMetadata + respMeta resultMetadata } func (f *framer) parseResultPrepared() frame { frame := &resultPreparedFrame{ frameHeader: *f.header, preparedID: f.readShortBytes(), - reqMeta: f.parsePreparedMetadata(), } + if f.proto > protoVersion4 { + frame.resultMetadataID = copyBytes(f.readShortBytes()) + } + + frame.reqMeta = f.parsePreparedMetadata() + if f.proto < protoVersion2 { return frame } @@ -1457,12 +1475,14 @@ type queryParams struct { defaultTimestamp bool defaultTimestampValue int64 // v5+ - keyspace string + keyspace string + useNowInSeconds bool + nowInSecondsValue int } func (q queryParams) String() string { - return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v keyspace=%s]", - q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, q.defaultTimestamp, q.values, q.keyspace) + return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v keyspace=%s now_in_seconds=%v]", + q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, q.defaultTimestamp, q.values, q.keyspace, q.useNowInSeconds) } func (f *framer) writeQueryParams(opts *queryParams) { @@ -1472,7 +1492,9 @@ func (f *framer) writeQueryParams(opts *queryParams) { return } - var flags byte + var flags uint32 + names := false + if len(opts.values) > 0 { flags |= flagValues } @@ -1489,8 +1511,6 @@ func (f *framer) writeQueryParams(opts *queryParams) { flags |= flagWithSerialConsistency } - names := false - // protoV3 specific things if f.proto > protoVersion2 { if opts.defaultTimestamp { @@ -1503,18 +1523,21 @@ func (f *framer) writeQueryParams(opts *queryParams) { } } - if opts.keyspace != "" { - if f.proto > protoVersion4 { + // protoV5 specific things + if f.proto > protoVersion4 { + if opts.keyspace != "" { flags |= flagWithKeyspace - } else { - panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) + } + + if opts.useNowInSeconds { + flags |= flagWithNowInSeconds } } if f.proto > protoVersion4 { - f.writeUint(uint32(flags)) + f.writeUint(flags) } else { - f.writeByte(flags) + f.writeByte(byte(flags)) } if n := len(opts.values); n > 0 { @@ -1555,8 +1578,14 @@ func (f *framer) writeQueryParams(opts *queryParams) { f.writeLong(ts) } - if opts.keyspace != "" { - f.writeString(opts.keyspace) + if f.proto > protoVersion4 { + if opts.keyspace != "" { + f.writeString(opts.keyspace) + } + + if opts.useNowInSeconds { + f.writeInt(int32(opts.nowInSecondsValue)) + } } } @@ -1604,6 +1633,9 @@ type writeExecuteFrame struct { // v4+ customPayload map[string][]byte + + // v5+ + resultMetadataID []byte } func (e *writeExecuteFrame) String() string { @@ -1611,16 +1643,21 @@ func (e *writeExecuteFrame) String() string { } func (e *writeExecuteFrame) buildFrame(fr *framer, streamID int) error { - return fr.writeExecuteFrame(streamID, e.preparedID, &e.params, &e.customPayload) + return fr.writeExecuteFrame(streamID, e.preparedID, e.resultMetadataID, &e.params, &e.customPayload) } -func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams, customPayload *map[string][]byte) error { +func (f *framer) writeExecuteFrame(streamID int, preparedID, resultMetadataID []byte, params *queryParams, customPayload *map[string][]byte) error { if len(*customPayload) > 0 { f.payload() } f.writeHeader(f.flags, opExecute, streamID) f.writeCustomPayload(customPayload) f.writeShortBytes(preparedID) + + if f.proto > protoVersion4 { + f.writeShortBytes(resultMetadataID) + } + if f.proto > protoVersion1 { f.writeQueryParams(params) } else { @@ -1659,6 +1696,11 @@ type writeBatchFrame struct { //v4+ customPayload map[string][]byte + + //v5+ + keyspace string + useNowInSeconds bool + nowInSecondsValue int } func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error { @@ -1676,7 +1718,7 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload n := len(w.statements) f.writeShort(uint16(n)) - var flags byte + var flags uint32 for i := 0; i < n; i++ { b := &w.statements[i] @@ -1717,25 +1759,46 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload if w.defaultTimestamp { flags |= flagDefaultTimestamp } + } - if f.proto > protoVersion4 { - f.writeUint(uint32(flags)) + if f.proto > protoVersion4 { + if w.keyspace != "" { + flags |= flagWithKeyspace + } + + if w.useNowInSeconds { + flags |= flagWithNowInSeconds + } + + } + + if f.proto > protoVersion4 { + f.writeUint(flags) + } else { + f.writeByte(byte(flags)) + } + + if w.serialConsistency > 0 { + f.writeConsistency(Consistency(w.serialConsistency)) + } + + if w.defaultTimestamp { + var ts int64 + if w.defaultTimestampValue != 0 { + ts = w.defaultTimestampValue } else { - f.writeByte(flags) + ts = time.Now().UnixNano() / 1000 } + f.writeLong(ts) + } - if w.serialConsistency > 0 { - f.writeConsistency(Consistency(w.serialConsistency)) + if f.proto > protoVersion4 { + if w.keyspace != "" { + f.writeString(w.keyspace) } - if w.defaultTimestamp { - var ts int64 - if w.defaultTimestampValue != 0 { - ts = w.defaultTimestampValue - } else { - ts = time.Now().UnixNano() / 1000 - } - f.writeLong(ts) + if w.useNowInSeconds { + f.writeInt(int32(w.nowInSecondsValue)) } } @@ -2070,3 +2133,247 @@ func (f *framer) writeBytesMap(m map[string][]byte) { f.writeBytes(v) } } + +func (f *framer) prepareModernLayout() error { + // Ensure protocol version is V5 or higher + if f.proto < protoVersion5 { + panic("Modern layout is not supported with version V4 or less") + } + + selfContained := true + + var ( + adjustedBuf []byte + tempBuf []byte + err error + ) + + // Process the buffer in chunks if it exceeds the max payload size + for len(f.buf) > maxPayloadSize { + if f.compres != nil { + tempBuf, err = newCompressedFrame(f.buf[:maxPayloadSize], false, f.compres) + } else { + tempBuf, err = newUncompressedFrame(f.buf[:maxPayloadSize], false) + } + if err != nil { + return err + } + + adjustedBuf = append(adjustedBuf, tempBuf...) + f.buf = f.buf[maxPayloadSize:] + selfContained = false + } + + // Process the remaining buffer + if f.compres != nil { + tempBuf, err = newCompressedFrame(f.buf, selfContained, f.compres) + } else { + tempBuf, err = newUncompressedFrame(f.buf, selfContained) + } + if err != nil { + return err + } + + adjustedBuf = append(adjustedBuf, tempBuf...) + f.buf = adjustedBuf + + return nil +} + +func readUncompressedFrame(r io.Reader) ([]byte, bool, error) { + const headerSize = 6 + header := [headerSize + 1]byte{} + + // Read the frame header + if _, err := io.ReadFull(r, header[:headerSize]); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read uncompressed frame, err: %w", err) + } + + // Compute and verify the header CRC24 + computedHeaderCRC24 := KoopmanChecksum(header[:3]) + readHeaderCRC24 := binary.LittleEndian.Uint32(header[3:]) & 0xFFFFFF + if computedHeaderCRC24 != readHeaderCRC24 { + return nil, false, fmt.Errorf("gocql: crc24 mismatch in frame header, computed: %d, got: %d", computedHeaderCRC24, readHeaderCRC24) + } + + // Extract the payload length and self-contained flag + headerInt := binary.LittleEndian.Uint32(header[:4]) + payloadLen := int(headerInt & 0x1FFFF) + isSelfContained := (headerInt & (1 << 17)) != 0 + + // Read the payload + payload := make([]byte, payloadLen) + if _, err := io.ReadFull(r, payload); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read uncompressed frame payload, err: %w", err) + } + + // Read and verify the payload CRC32 + if _, err := io.ReadFull(r, header[:4]); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read payload crc32, err: %w", err) + } + + computedPayloadCRC32 := ChecksumIEEE(payload) + readPayloadCRC32 := binary.LittleEndian.Uint32(header[:4]) + if computedPayloadCRC32 != readPayloadCRC32 { + return nil, false, fmt.Errorf("gocql: payload crc32 mismatch, computed: %d, got: %d", computedPayloadCRC32, readPayloadCRC32) + } + + return payload, isSelfContained, nil +} + +func newUncompressedFrame(payload []byte, isSelfContained bool) ([]byte, error) { + const ( + headerSize = 6 + selfContainedBit = 1 << 17 + ) + + payloadLen := len(payload) + if payloadLen > maxPayloadSize { + return nil, fmt.Errorf("payload length (%d) exceeds maximum size of 128 KiB", payloadLen) + } + + // Create the frame + frameSize := headerSize + payloadLen + 4 // 4 bytes for CRC32 + frame := make([]byte, frameSize) + + // First 3 bytes: payload length and self-contained flag + headerInt := uint32(payloadLen) & 0x1FFFF + if isSelfContained { + headerInt |= selfContainedBit // Set the self-contained flag + } + + // Encode the first 3 bytes as a single little-endian integer + frame[0] = byte(headerInt) + frame[1] = byte(headerInt >> 8) + frame[2] = byte(headerInt >> 16) + + // Calculate CRC24 for the first 3 bytes of the header + crc := KoopmanChecksum(frame[:3]) + + // Encode CRC24 into the next 3 bytes of the header + frame[3] = byte(crc) + frame[4] = byte(crc >> 8) + frame[5] = byte(crc >> 16) + + copy(frame[headerSize:], payload) // Copy the payload to the frame + + // Calculate CRC32 for the payload + payloadCRC32 := ChecksumIEEE(payload) + binary.LittleEndian.PutUint32(frame[headerSize+payloadLen:], payloadCRC32) + + return frame, nil +} + +func newCompressedFrame(uncompressedPayload []byte, isSelfContained bool, compressor Compressor) ([]byte, error) { + uncompressedLen := len(uncompressedPayload) + if uncompressedLen > maxPayloadSize { + return nil, fmt.Errorf("uncompressed compressed payload length exceedes max size of frame payload %d/%d", uncompressedLen, maxPayloadSize) + } + + compressedPayload, err := compressor.Encode(uncompressedPayload) + if err != nil { + return nil, err + } + + // Skip the first 4 bytes because the size of the uncompressed payload is written in the frame header, not in the + // body of the compressed envelope + compressedPayload = compressedPayload[4:] + + compressedLen := len(compressedPayload) + + // Compression is not worth it + if uncompressedLen < compressedLen { + // native_protocol_v5.spec + // 2.2 + // An uncompressed length of 0 signals that the compressed payload + // should be used as-is and not decompressed. + compressedPayload = uncompressedPayload + compressedLen = uncompressedLen + uncompressedLen = 0 + } + + // Combine compressed and uncompressed lengths and set the self-contained flag if needed + combined := uint64(compressedLen) | uint64(uncompressedLen)<<17 + if isSelfContained { + combined |= 1 << 34 + } + + var headerBuf [8]byte + + // Write the combined value into the header buffer + binary.LittleEndian.PutUint64(headerBuf[:], combined) + + // Create a buffer with enough capacity to hold the header, compressed payload, and checksums + buf := bytes.NewBuffer(make([]byte, 0, 8+compressedLen+4)) + + // Write the first 5 bytes of the header (compressed and uncompressed sizes) + buf.Write(headerBuf[:5]) + + // Compute and write the CRC24 checksum of the first 5 bytes + headerChecksum := KoopmanChecksum(headerBuf[:5]) + binary.LittleEndian.PutUint32(headerBuf[:], headerChecksum) + buf.Write(headerBuf[:3]) + buf.Write(compressedPayload) + + // Compute and write the CRC32 checksum of the payload + payloadChecksum := ChecksumIEEE(compressedPayload) + binary.LittleEndian.PutUint32(headerBuf[:], payloadChecksum) + buf.Write(headerBuf[:4]) + + return buf.Bytes(), nil +} + +func readCompressedFrame(r io.Reader, compressor Compressor) ([]byte, bool, error) { + var ( + headerBuf [8]byte + err error + ) + + if _, err = io.ReadFull(r, headerBuf[:]); err != nil { + return nil, false, err + } + + // Reading checksum from frame header + readHeaderChecksum := uint32(headerBuf[5]) | uint32(headerBuf[6])<<8 | uint32(headerBuf[7])<<16 + if computedHeaderChecksum := KoopmanChecksum(headerBuf[:5]); computedHeaderChecksum != readHeaderChecksum { + return nil, false, fmt.Errorf("gocql: crc24 mismatch in frame header, read: %d, computed: %d", readHeaderChecksum, computedHeaderChecksum) + } + + // First 17 bits - payload size after compression + compressedLen := uint32(headerBuf[0]) | uint32(headerBuf[1])<<8 | uint32(headerBuf[2]&0x1)<<16 + + // The next 17 bits - payload size before compression + uncompressedLen := (uint32(headerBuf[2]) >> 1) | uint32(headerBuf[3])<<7 | uint32(headerBuf[4]&0b11)<<15 + + // Self-contained flag + selfContained := (headerBuf[4] & 0b100) != 0 + + compressedPayload := make([]byte, compressedLen) + if _, err = io.ReadFull(r, compressedPayload); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read compressed frame payload, err: %w", err) + } + + if _, err = io.ReadFull(r, headerBuf[:4]); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read payload crc32, err: %w", err) + } + + // Ensuring if payload checksum matches + readPayloadChecksum := binary.LittleEndian.Uint32(headerBuf[:4]) + if computedPayloadChecksum := ChecksumIEEE(compressedPayload); readPayloadChecksum != computedPayloadChecksum { + return nil, false, fmt.Errorf("gocql: crc32 mismatch in payload, read: %d, computed: %d", readPayloadChecksum, computedPayloadChecksum) + } + + var uncompressedPayload []byte + if uncompressedLen > 0 { + if uncompressedPayload, err = compressor.DecodeSized(compressedPayload, uncompressedLen); err != nil { + return nil, false, err + } + if uint32(len(uncompressedPayload)) != uncompressedLen { + return nil, false, fmt.Errorf("gocql: length mismatch after payload decoding, got %d, expected %d", len(uncompressedPayload), uncompressedLen) + } + } else { + uncompressedPayload = compressedPayload + } + + return uncompressedPayload, selfContained, nil +} diff --git a/frame_test.go b/frame_test.go index 170cba710..27c96581b 100644 --- a/frame_test.go +++ b/frame_test.go @@ -26,6 +26,10 @@ package gocql import ( "bytes" + "encoding/binary" + "errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "os" "testing" ) @@ -127,3 +131,315 @@ func TestFrameReadTooLong(t *testing.T) { t.Fatalf("expected to get header %v got %v", opReady, head.op) } } + +func Test_framer_writeExecuteFrame(t *testing.T) { + framer := newFramer(nil, protoVersion5) + frame := writeExecuteFrame{ + preparedID: []byte{1, 2, 3}, + resultMetadataID: []byte{4, 5, 6}, + customPayload: map[string][]byte{ + "key1": []byte("value1"), + }, + params: queryParams{ + useNowInSeconds: true, + nowInSecondsValue: 123, + keyspace: "test_keyspace", + }, + } + + err := framer.writeExecuteFrame(123, frame.preparedID, frame.resultMetadataID, &frame.params, &frame.customPayload) + if err != nil { + t.Fatal(err) + } + + // skipping header + framer.buf = framer.buf[9:] + + assertDeepEqual(t, "customPayload", frame.customPayload, framer.readBytesMap()) + assertDeepEqual(t, "preparedID", frame.preparedID, framer.readShortBytes()) + assertDeepEqual(t, "resultMetadataID", frame.resultMetadataID, framer.readShortBytes()) + assertDeepEqual(t, "constistency", frame.params.consistency, Consistency(framer.readShort())) + + flags := framer.readInt() + if flags&int(flagWithNowInSeconds) != int(flagWithNowInSeconds) { + t.Fatal("expected flagNowInSeconds to be set, but it is not") + } + + if flags&int(flagWithKeyspace) != int(flagWithKeyspace) { + t.Fatal("expected flagWithKeyspace to be set, but it is not") + } + + assertDeepEqual(t, "keyspace", frame.params.keyspace, framer.readString()) + assertDeepEqual(t, "useNowInSeconds", frame.params.nowInSecondsValue, framer.readInt()) +} + +func Test_framer_writeBatchFrame(t *testing.T) { + framer := newFramer(nil, protoVersion5) + frame := writeBatchFrame{ + customPayload: map[string][]byte{ + "key1": []byte("value1"), + }, + useNowInSeconds: true, + nowInSecondsValue: 123, + } + + err := framer.writeBatchFrame(123, &frame, frame.customPayload) + if err != nil { + t.Fatal(err) + } + + // skipping header + framer.buf = framer.buf[9:] + + assertDeepEqual(t, "customPayload", frame.customPayload, framer.readBytesMap()) + assertDeepEqual(t, "typ", frame.typ, BatchType(framer.readByte())) + assertDeepEqual(t, "len(statements)", len(frame.statements), int(framer.readShort())) + assertDeepEqual(t, "consistency", frame.consistency, Consistency(framer.readShort())) + + flags := framer.readInt() + if flags&int(flagWithNowInSeconds) != int(flagWithNowInSeconds) { + t.Fatal("expected flagNowInSeconds to be set, but it is not") + } + + assertDeepEqual(t, "useNowInSeconds", frame.nowInSecondsValue, framer.readInt()) +} + +type testMockedCompressor struct { + // this is an error its methods should return + expectedError error + + // invalidateDecodedDataLength allows to simulate data decoding invalidation + invalidateDecodedDataLength bool +} + +func (m testMockedCompressor) Name() string { + return "testMockedCompressor" +} + +func (m testMockedCompressor) Encode(data []byte) ([]byte, error) { + encoded := make([]byte, len(data)+4) + binary.BigEndian.PutUint32(encoded, uint32(len(data))) + copy(encoded[4:], data) + if m.expectedError != nil { + return nil, m.expectedError + } + return encoded, nil +} + +func (m testMockedCompressor) Decode(data []byte) ([]byte, error) { + if m.expectedError != nil { + return nil, m.expectedError + } + return data, nil +} + +func (m testMockedCompressor) DecodeSized(data []byte, size uint32) ([]byte, error) { + if m.expectedError != nil { + return nil, m.expectedError + } + + // simulating invalid size of decoded data + if m.invalidateDecodedDataLength { + return data[:size-1], nil + } + + return data, nil +} + +func Test_readUncompressedFrame(t *testing.T) { + tests := []struct { + name string + modifyFrame func([]byte) []byte + expectedErr string + }{ + { + name: "header crc24 mismatch", + modifyFrame: func(frame []byte) []byte { + // simulating some crc invalidation + frame[0] = 255 + return frame + }, + expectedErr: "gocql: crc24 mismatch in frame header", + }, + { + name: "body crc32 mismatch", + modifyFrame: func(frame []byte) []byte { + // simulating body crc32 mismatch + frame[len(frame)-1] = 255 + return frame + }, + expectedErr: "gocql: payload crc32 mismatch", + }, + { + name: "invalid frame length", + modifyFrame: func(frame []byte) []byte { + // simulating body length invalidation + frame = frame[:7] + return frame + }, + expectedErr: "gocql: failed to read uncompressed frame payload", + }, + { + name: "cannot read body checksum", + modifyFrame: func(frame []byte) []byte { + // simulating body length invalidation + frame = frame[:len(frame)-4] + return frame + }, + expectedErr: "gocql: failed to read payload crc32", + }, + { + name: "success", + modifyFrame: nil, + expectedErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + framer := newFramer(nil, protoVersion5) + req := writeQueryFrame{ + statement: "SELECT * FROM system.local", + params: queryParams{ + consistency: Quorum, + keyspace: "gocql_test", + }, + } + + err := req.buildFrame(framer, 128) + require.NoError(t, err) + + frame, err := newUncompressedFrame(framer.buf, true) + require.NoError(t, err) + + if tt.modifyFrame != nil { + frame = tt.modifyFrame(frame) + } + + readFrame, isSelfContained, err := readUncompressedFrame(bytes.NewReader(frame)) + + if tt.expectedErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedErr) + } else { + require.NoError(t, err) + assert.True(t, isSelfContained) + assert.Equal(t, framer.buf, readFrame) + } + }) + } +} + +func Test_readCompressedFrame(t *testing.T) { + tests := []struct { + name string + // modifyFrameFn is useful for simulating frame data invalidation + modifyFrameFn func([]byte) []byte + compressor testMockedCompressor + + // expectedErrorMsg is an error message that should be returned by Error() method. + // We need this to understand which of fmt.Errorf() is returned + expectedErrorMsg string + }{ + { + name: "header crc24 mismatch", + modifyFrameFn: func(frame []byte) []byte { + // simulating some crc invalidation + frame[0] = 255 + return frame + }, + expectedErrorMsg: "gocql: crc24 mismatch in frame header", + }, + { + name: "body crc32 mismatch", + modifyFrameFn: func(frame []byte) []byte { + // simulating body crc32 mismatch + frame[len(frame)-1] = 255 + return frame + }, + expectedErrorMsg: "gocql: crc32 mismatch in payload", + }, + { + name: "invalid frame length", + modifyFrameFn: func(frame []byte) []byte { + // simulating body length invalidation + return frame[:12] + }, + expectedErrorMsg: "gocql: failed to read compressed frame payload", + }, + { + name: "cannot read body checksum", + modifyFrameFn: func(frame []byte) []byte { + // simulating body length invalidation + return frame[:len(frame)-4] + }, + expectedErrorMsg: "gocql: failed to read payload crc32", + }, + { + name: "failed to encode payload", + modifyFrameFn: nil, + compressor: testMockedCompressor{ + expectedError: errors.New("failed to encode payload"), + }, + expectedErrorMsg: "failed to encode payload", + }, + { + name: "failed to decode payload", + modifyFrameFn: nil, + compressor: testMockedCompressor{ + expectedError: errors.New("failed to decode payload"), + }, + expectedErrorMsg: "failed to decode payload", + }, + { + name: "length mismatch after decoding", + modifyFrameFn: nil, + compressor: testMockedCompressor{ + invalidateDecodedDataLength: true, + }, + expectedErrorMsg: "gocql: length mismatch after payload decoding", + }, + { + name: "success", + modifyFrameFn: nil, + expectedErrorMsg: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + framer := newFramer(nil, protoVersion5) + req := writeQueryFrame{ + statement: "SELECT * FROM system.local", + params: queryParams{ + consistency: Quorum, + keyspace: "gocql_test", + }, + } + + err := req.buildFrame(framer, 128) + require.NoError(t, err) + + frame, err := newCompressedFrame(framer.buf, true, testMockedCompressor{}) + require.NoError(t, err) + + if tt.modifyFrameFn != nil { + frame = tt.modifyFrameFn(frame) + } + + readFrame, selfContained, err := readCompressedFrame(bytes.NewReader(frame), tt.compressor) + + switch { + case tt.expectedErrorMsg != "": + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedErrorMsg) + case tt.compressor.expectedError != nil: + require.ErrorIs(t, err, tt.compressor.expectedError) + default: + require.NoError(t, err) + assert.True(t, selfContained) + assert.Equal(t, framer.buf, readFrame) + } + }) + } +} diff --git a/go.mod b/go.mod index 0aea881ec..d90f9d9fe 100644 --- a/go.mod +++ b/go.mod @@ -20,11 +20,14 @@ module github.com/gocql/gocql require ( github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect + github.com/gocql/gocql/lz4 v0.0.0-20240925165811-953e0df999ca github.com/golang/snappy v0.0.3 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed github.com/kr/pretty v0.1.0 // indirect - github.com/stretchr/testify v1.3.0 // indirect + github.com/stretchr/testify v1.9.0 gopkg.in/inf.v0 v0.9.1 ) go 1.13 + +replace github.com/gocql/gocql/lz4 => ./lz4 diff --git a/go.sum b/go.sum index 2e3892bcb..67083bcb9 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,11 @@ github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYE github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gocql/gocql/lz4 v0.0.0-20240925165811-953e0df999ca h1:3kQmFuM8n9V4qKNr1HaPJ6ZNTp/9uKPoqccE17fTIiQ= +github.com/gocql/gocql/lz4 v0.0.0-20240925165811-953e0df999ca/go.mod h1:81C8vezOSDewKcugB728RMD1LXojw/fHx1No4w7Jmm4= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= @@ -13,10 +16,24 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4= +github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lz4/lz4.go b/lz4/lz4.go index 049fdc0bb..eb8329ee6 100644 --- a/lz4/lz4.go +++ b/lz4/lz4.go @@ -73,3 +73,13 @@ func (s LZ4Compressor) Decode(data []byte) ([]byte, error) { n, err := lz4.UncompressBlock(data[4:], buf) return buf[:n], err } + +func (s LZ4Compressor) DecodeSized(data []byte, size uint32) ([]byte, error) { + buf := make([]byte, size) + _, err := lz4.UncompressBlock(data, buf) + if err != nil { + return nil, err + } + + return buf, nil +} diff --git a/prepared_cache.go b/prepared_cache.go index 3fd256d33..7f5533a2d 100644 --- a/prepared_cache.go +++ b/prepared_cache.go @@ -100,3 +100,20 @@ func (p *preparedLRU) evictPreparedID(key string, id []byte) { } } + +func (p *preparedLRU) get(key string) (*inflightPrepare, bool) { + p.mu.Lock() + defer p.mu.Unlock() + + val, ok := p.lru.Get(key) + if !ok { + return nil, false + } + + ifp, ok := val.(*inflightPrepare) + if !ok { + return nil, false + } + + return ifp, true +} diff --git a/session.go b/session.go index a600b95f3..98fb7cfd7 100644 --- a/session.go +++ b/session.go @@ -936,6 +936,10 @@ type Query struct { // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. routingInfo *queryRoutingInfo + + keyspace string + useNowInSeconds bool + nowInSecondsValue int } type queryRoutingInfo struct { @@ -1423,6 +1427,25 @@ func (q *Query) releaseAfterExecution() { q.decRefCount() } +// SetKeyspace will enable keyspace flag on the query. +// It allows to specify the keyspace that the query should be executed in +// +// Only available on protocol >= 5. +func (q *Query) SetKeyspace(keyspace string) *Query { + q.keyspace = keyspace + return q +} + +// WithNowInSeconds will enable the with now_in_seconds flag on the query. +// Also, it allows to define now_in_seconds value. +// +// Only available on protocol >= 5. +func (q *Query) WithNowInSeconds(now int) *Query { + q.useNowInSeconds = true + q.nowInSecondsValue = now + return q +} + // Iter represents an iterator that can be used to iterate over all rows that // were returned by a query. The iterator might send additional queries to the // database during the iteration if paging was enabled. @@ -1742,6 +1765,8 @@ type Batch struct { cancelBatch func() keyspace string metrics *queryMetrics + useNowInSeconds bool + nowInSecondsValue int // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. routingInfo *queryRoutingInfo @@ -2042,6 +2067,25 @@ func (b *Batch) releaseAfterExecution() { // that would race with speculative executions. } +// SetKeyspace will enable keyspace flag on the query. +// It allows to specify the keyspace that the query should be executed in +// +// Only available on protocol >= 5. +func (b *Batch) SetKeyspace(keyspace string) *Batch { + b.keyspace = keyspace + return b +} + +// WithNowInSeconds will enable the with now_in_seconds flag on the query. +// Also, it allows to define now_in_seconds value. +// +// Only available on protocol >= 5. +func (b *Batch) WithNowInSeconds(now int) *Batch { + b.useNowInSeconds = true + b.nowInSecondsValue = now + return b +} + type BatchType byte const (