Skip to content

Commit

Permalink
Support for Native Protocol 5 release version
Browse files Browse the repository at this point in the history
  • Loading branch information
worryg0d committed Sep 24, 2024
1 parent 974fa12 commit 47a3ae9
Show file tree
Hide file tree
Showing 16 changed files with 1,424 additions and 46 deletions.
82 changes: 82 additions & 0 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
package gocql

import (
"github.com/stretchr/testify/require"
"testing"
"time"
)
Expand Down Expand Up @@ -84,3 +85,84 @@ func TestBatch_WithTimestamp(t *testing.T) {
t.Errorf("got ts %d, expected %d", storedTs, micros)
}
}

func TestBatch_WithNowInSeconds(t *testing.T) {
session := createSession(t)
defer session.Close()

if session.cfg.ProtoVersion < protoVersion5 {
t.Skip("Batch now in seconds are only available on protocol >= 5")
}

if err := createTable(session, `CREATE TABLE batch_now_in_seconds (id int primary key, val text)`); err != nil {
t.Fatal(err)
}

b := session.NewBatch(LoggedBatch)
b.WithNowInSeconds(0)
b.Query("INSERT INTO batch_now_in_seconds (id, val) VALUES (?, ?) USING TTL 20", 1, "val")
if err := session.ExecuteBatch(b); err != nil {
t.Fatal(err)
}

var remainingTTL int
err := session.Query(`SELECT TTL(val) FROM batch_now_in_seconds WHERE id = ?`, 1).
WithNowInSeconds(10).
Scan(&remainingTTL)
if err != nil {
t.Fatal(err)
}

require.Equal(t, remainingTTL, 10)
}

func TestBatch_SetKeyspace(t *testing.T) {
session := createSession(t)
defer session.Close()

if session.cfg.ProtoVersion < protoVersion5 {
t.Skip("keyspace for BATCH message is not supported in protocol < 5")
}

const keyspaceStmt = `
CREATE KEYSPACE IF NOT EXISTS gocql_keyspace_override_test
WITH replication = {
'class': 'SimpleStrategy',
'replication_factor': '1'
};
`

err := session.Query(keyspaceStmt).Exec()
if err != nil {
t.Fatal(err)
}

err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_keyspace_override_test.batch_keyspace(id int, value text, PRIMARY KEY (id))")
if err != nil {
t.Fatal(err)
}

ids := []int{1, 2}
texts := []string{"val1", "val2"}

b := session.NewBatch(LoggedBatch).SetKeyspace("gocql_keyspace_override_test")
b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[0], texts[0])
b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[1], texts[1])
err = session.ExecuteBatch(b)
if err != nil {
t.Fatal(err)
}

var (
id int
text string
)

iter := session.Query("SELECT * FROM gocql_keyspace_override_test.batch_keyspace").Iter()
defer iter.Close()

for i := 0; iter.Scan(&id, &text); i++ {
require.Equal(t, id, ids[i])
require.Equal(t, text, texts[i])
}
}
148 changes: 148 additions & 0 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"context"
"errors"
"fmt"
"github.com/stretchr/testify/require"
"io"
"math"
"math/big"
Expand Down Expand Up @@ -3288,3 +3289,150 @@ func TestQuery_NamedValues(t *testing.T) {
t.Fatal(err)
}
}

func TestQuery_WithNowInSeconds(t *testing.T) {
session := createSession(t)
defer session.Close()

if session.cfg.ProtoVersion < protoVersion5 {
t.Skip("Query now in seconds are only available on protocol >= 5")
}

if err := createTable(session, `CREATE TABLE query_now_in_seconds (id int primary key, val text)`); err != nil {
t.Fatal(err)
}

err := session.Query("INSERT INTO query_now_in_seconds (id, val) VALUES (?, ?) USING TTL 20", 1, "val").
WithNowInSeconds(int(0)).
Exec()
if err != nil {
t.Fatal(err)
}

var remainingTTL int
err = session.Query(`SELECT TTL(val) FROM query_now_in_seconds WHERE id = ?`, 1).
WithNowInSeconds(10).
Scan(&remainingTTL)
if err != nil {
t.Fatal(err)
}

require.Equal(t, remainingTTL, 10)
}

func TestQuery_SetKeyspace(t *testing.T) {
session := createSession(t)
defer session.Close()

if session.cfg.ProtoVersion < protoVersion5 {
t.Skip("keyspace for QUERY message is not supported in protocol < 5")
}

const keyspaceStmt = `
CREATE KEYSPACE IF NOT EXISTS gocql_query_keyspace_override_test
WITH replication = {
'class': 'SimpleStrategy',
'replication_factor': '1'
};
`

err := session.Query(keyspaceStmt).Exec()
if err != nil {
t.Fatal(err)
}

err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_query_keyspace_override_test.query_keyspace(id int, value text, PRIMARY KEY (id))")
if err != nil {
t.Fatal(err)
}

expectedID := 1
expectedText := "text"

// Testing PREPARE message
err = session.Query("INSERT INTO gocql_query_keyspace_override_test.query_keyspace (id, value) VALUES (?, ?)", expectedID, expectedText).Exec()
if err != nil {
t.Fatal(err)
}

var (
id int
text string
)

q := session.Query("SELECT * FROM gocql_query_keyspace_override_test.query_keyspace").
SetKeyspace("gocql_query_keyspace_override_test")
err = q.Scan(&id, &text)
if err != nil {
t.Fatal(err)
}

require.Equal(t, expectedID, id)
require.Equal(t, expectedText, text)

// Testing QUERY message
id = 0
text = ""

q = session.Query("SELECT * FROM gocql_query_keyspace_override_test.query_keyspace").
SetKeyspace("gocql_query_keyspace_override_test")
q.skipPrepare = true
err = q.Scan(&id, &text)
if err != nil {
t.Fatal(err)
}

require.Equal(t, expectedID, id)
require.Equal(t, expectedText, text)
}

func TestLargeSizeQuery(t *testing.T) {
session := createSession(t)
defer session.Close()

if err := createTable(session, "CREATE TABLE gocql_test.large_size_query(id int, text_col text, PRIMARY KEY (id))"); err != nil {
t.Fatal(err)
}

defer session.Close()

longString := strings.Repeat("a", 500_000)

err := session.Query("INSERT INTO gocql_test.large_size_query (id, text_col) VALUES (?, ?)", "1", longString).Exec()
if err != nil {
t.Fatal(err)
}

var result string
err = session.Query("SELECT text_col FROM gocql_test.large_size_query").Scan(&result)
if err != nil {
t.Fatal(err)
}

require.Equal(t, longString, result)
}

func TestQueryCompressionNotWorthIt(t *testing.T) {
session := createSession(t)
defer session.Close()

if err := createTable(session, "CREATE TABLE gocql_test.compression_now_worth_it(id int, text_col text, PRIMARY KEY (id))"); err != nil {
t.Fatal(err)
}

defer session.Close()

str := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890!@#$%^&*()_+"
err := session.Query("INSERT INTO gocql_test.large_size_query (id, text_col) VALUES (?, ?)", "1", str).Exec()
if err != nil {
t.Fatal(err)
}

var result string
err = session.Query("SELECT text_col FROM gocql_test.large_size_query").Scan(&result)
if err != nil {
t.Fatal(err)
}

require.Equal(t, str, result)
}
2 changes: 2 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig {
switch *flagCompressTest {
case "snappy":
cluster.Compressor = &SnappyCompressor{}
case "lz4":
cluster.Compressor = &LZ4Compressor{}
case "":
default:
panic("invalid compressor: " + *flagCompressTest)
Expand Down
52 changes: 52 additions & 0 deletions compressor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@
package gocql

import (
"encoding/binary"
"fmt"
"github.com/golang/snappy"
"github.com/pierrec/lz4/v4"
)

type Compressor interface {
Name() string
Encode(data []byte) ([]byte, error)
Decode(data []byte) ([]byte, error)
DecodeSized(data []byte, size uint32) ([]byte, error)
}

// SnappyCompressor implements the Compressor interface and can be used to
Expand All @@ -50,3 +54,51 @@ func (s SnappyCompressor) Encode(data []byte) ([]byte, error) {
func (s SnappyCompressor) Decode(data []byte) ([]byte, error) {
return snappy.Decode(nil, data)
}

func (s SnappyCompressor) DecodeSized(data []byte, size uint32) ([]byte, error) {
buf := make([]byte, size)
return snappy.Decode(buf, data)
}

type LZ4Compressor struct{}

func (s LZ4Compressor) Name() string {
return "lz4"
}

func (s LZ4Compressor) Encode(data []byte) ([]byte, error) {
buf := make([]byte, lz4.CompressBlockBound(len(data)+4))
var compressor lz4.Compressor
n, err := compressor.CompressBlock(data, buf[4:])
// According to lz4.CompressBlock doc, it doesn't fail as long as the dst
// buffer length is at least lz4.CompressBlockBound(len(data))) bytes, but
// we check for error anyway just to be thorough.
if err != nil {
return nil, err
}
binary.BigEndian.PutUint32(buf, uint32(len(data)))
return buf[:n+4], nil
}

func (s LZ4Compressor) Decode(data []byte) ([]byte, error) {
if len(data) < 4 {
return nil, fmt.Errorf("cassandra lz4 block size should be >4, got=%d", len(data))
}
uncompressedLength := binary.BigEndian.Uint32(data)
if uncompressedLength == 0 {
return nil, nil
}
buf := make([]byte, uncompressedLength)
n, err := lz4.UncompressBlock(data[4:], buf)
return buf[:n], err
}

func (s LZ4Compressor) DecodeSized(data []byte, size uint32) ([]byte, error) {
buf := make([]byte, size)
_, err := lz4.UncompressBlock(data, buf)
if err != nil {
return nil, err
}

return buf, nil
}
Loading

0 comments on commit 47a3ae9

Please sign in to comment.