Skip to content

Commit

Permalink
feat: 1. add support op_code_module_aux(247) 2. set default skip read…
Browse files Browse the repository at this point in the history
… module data
  • Loading branch information
zonewave authored and HDT3213 committed Nov 22, 2023
1 parent b965ee0 commit f774ec2
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 70 deletions.
7 changes: 7 additions & 0 deletions core/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ const (
)

const (
opCodeModuleAux = 247 /* Module auxiliary data. */
opCodeIdle = 248 /* LRU idle time. */
opCodeFreq = 249 /* LFU frequency. */
opCodeAux = 250 /* RDB aux field. */
Expand Down Expand Up @@ -395,6 +396,12 @@ func (dec *Decoder) parse(cb func(object model.RedisObject) bool) error {
return err
}
continue
} else if b == opCodeModuleAux {
_, _, err = dec.readModuleType()
if err != nil {
return err
}
continue
}
key, err := dec.readString()
if err != nil {
Expand Down
40 changes: 39 additions & 1 deletion core/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type ModuleTypeHandler interface {
ReadOpcode() (Opcode, error)
ReadUInt() (uint64, error)
ReadSInt() (int64, error)
ReadFloat32() (float32, error)
ReadDouble() (float64, error)
ReadString() ([]byte, error)
ReadLength() (uint64, bool, error)
Expand Down Expand Up @@ -60,6 +61,9 @@ func (m moduleTypeHandlerImpl) ReadSInt() (int64, error) {
return int64(val), err
}

func (m moduleTypeHandlerImpl) ReadFloat32() (float32, error) {
return m.dec.readFloat32()
}
func (m moduleTypeHandlerImpl) ReadDouble() (float64, error) {
return m.dec.readFloat()
}
Expand All @@ -86,7 +90,8 @@ func (dec *Decoder) handleModuleType(moduleId uint64) (string, interface{}, erro
moduleType := moduleTypeNameByID(moduleId)
handler, found := dec.withSpecialTypes[moduleType]
if !found {
return moduleType, nil, fmt.Errorf("unknown module type: %s", moduleType)
fmt.Printf("unknown module type: %s,will skip\n", moduleType)
handler = skipModuleAuxData
}
encVersion := moduleTypeEncVersionByID(moduleId)
val, err := handler(moduleTypeHandlerImpl{dec: dec}, int(encVersion))
Expand All @@ -108,4 +113,37 @@ func moduleTypeEncVersionByID(moduleId uint64) uint64 {
return moduleId & 1023
}

// skipModuleAuxData skips module aux data
func skipModuleAuxData(h ModuleTypeHandler, _ int) (interface{}, error) {
opCode, err := h.ReadOpcode()
if err != nil {
return nil, err
}
for opCode != ModuleOpcodeEOF {
switch opCode {
case ModuleOpcodeSInt:
_, err = h.ReadSInt()
case ModuleOpcodeUInt:
_, err = h.ReadUInt()
case ModuleOpcodeFloat:
_, err = h.ReadFloat32()
case ModuleOpcodeDouble:
_, err = h.ReadDouble()
case ModuleOpcodeString:
_, err = h.ReadString()
default:
err = fmt.Errorf("unknown module opcode %d", opCode)
}
if err != nil {
return nil, err
}
opCode, err = h.ReadOpcode()
if err != nil {
return nil, err
}
}

return nil, nil
}

const ModuleTypeNameCharSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
168 changes: 99 additions & 69 deletions core/module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,87 +84,117 @@ func TestModuleType(t *testing.T) {
t.Error(err)
return
}

expectedResult := "expected-result"

dec := NewDecoder(buf).WithSpecialType(testModuleType,
func(h ModuleTypeHandler, encVersion int) (interface{}, error) {
if encVersion != 42 {
t.Errorf("invalid encoding version, expected %d, actual %d",
expectedModuleEncVersion, encVersion)
return nil, fmt.Errorf("invalid encoding version, expected %d, actual %d",
expectedModuleEncVersion, encVersion)
}

opcode, err := h.ReadOpcode()
if err != nil {
return nil, err
}
if opcode != ModuleOpcodeString {
return nil, fmt.Errorf("invalid opcode read, expected %d (string), actual %d",
ModuleOpcodeString, opcode)
t.Run("with parse", func(t *testing.T) {

expectedResult := "expected-result"

dec := NewDecoder(buf).WithSpecialType(testModuleType,
func(h ModuleTypeHandler, encVersion int) (interface{}, error) {
if encVersion != 42 {
t.Errorf("invalid encoding version, expected %d, actual %d",
expectedModuleEncVersion, encVersion)
return nil, fmt.Errorf("invalid encoding version, expected %d, actual %d",
expectedModuleEncVersion, encVersion)
}

opcode, err := h.ReadOpcode()
if err != nil {
return nil, err
}
if opcode != ModuleOpcodeString {
return nil, fmt.Errorf("invalid opcode read, expected %d (string), actual %d",
ModuleOpcodeString, opcode)
}
data, err := h.ReadString()
if err != nil {
return nil, err
}
if !bytes.Equal(data, []byte(expectedStrData)) {
return nil, fmt.Errorf("invalid string data read, expected %s, actual %s",
expectedStrData, string(data))
}

opcode, err = h.ReadOpcode()
if err != nil {
return nil, err
}
if opcode != ModuleOpcodeUInt {
return nil, fmt.Errorf("invalid opcode read, expected %d (uint), actual %d",
ModuleOpcodeUInt, opcode)
}
val, err := h.ReadUInt()
if err != nil {
return nil, err
}
if val != expectedUInt {
return nil, fmt.Errorf("invalid unsigned int read, expected %d, actual %d",
expectedUInt, val)
}
opcode, err = h.ReadOpcode()
if err != nil {
return nil, err
}
if opcode != ModuleOpcodeEOF {
return nil, fmt.Errorf("invalid opcode read, expected %d (EOF), actual %d",
ModuleOpcodeEOF, opcode)
}
return expectedResult, nil
})

err = dec.Parse(func(o model.RedisObject) bool {
if o.GetKey() != key {
t.Errorf("invalid object key, expected %s, actual %s", key, o.GetKey())
return false
}
data, err := h.ReadString()
if err != nil {
return nil, err
if o.GetType() != testModuleType {
t.Errorf("invalid redis type, expected %s, actual %s", testModuleType, o.GetType())
return false
}
if !bytes.Equal(data, []byte(expectedStrData)) {
return nil, fmt.Errorf("invalid string data read, expected %s, actual %s",
expectedStrData, string(data))
mtObj, ok := o.(*model.ModuleTypeObject)
if !ok {
t.Errorf("invalid object type, expected model.ModuleTypeObject")
return false
}

opcode, err = h.ReadOpcode()
if err != nil {
return nil, err
}
if opcode != ModuleOpcodeUInt {
return nil, fmt.Errorf("invalid opcode read, expected %d (uint), actual %d",
ModuleOpcodeUInt, opcode)
if mtObj.Value != expectedResult {
t.Errorf("invalid return value")
return false
}
val, err := h.ReadUInt()
if err != nil {
return nil, err
}
if val != expectedUInt {
return nil, fmt.Errorf("invalid unsigned int read, expected %d, actual %d",
expectedUInt, val)

return true
})
if err != nil {
t.Error(err)
}
})
t.Run("skip parse", func(t *testing.T) {
dec := NewDecoder(buf)
err = dec.Parse(func(o model.RedisObject) bool {
if o.GetKey() != key {
t.Errorf("invalid object key, expected %s, actual %s", key, o.GetKey())
return false
}
opcode, err = h.ReadOpcode()
if err != nil {
return nil, err
if o.GetType() != testModuleType {
t.Errorf("invalid redis type, expected %s, actual %s", testModuleType, o.GetType())
return false
}
if opcode != ModuleOpcodeEOF {
return nil, fmt.Errorf("invalid opcode read, expected %d (EOF), actual %d",
ModuleOpcodeEOF, opcode)
mtObj, ok := o.(*model.ModuleTypeObject)
if !ok {
t.Errorf("invalid object type, expected model.ModuleTypeObject")
return false
}
return expectedResult, nil
})

err = dec.Parse(func(o model.RedisObject) bool {
if o.GetKey() != key {
t.Errorf("invalid object key, expected %s, actual %s", key, o.GetKey())
return false
}
if o.GetType() != testModuleType {
t.Errorf("invalid redis type, expected %s, actual %s", testModuleType, o.GetType())
return false
}
mtObj, ok := o.(*model.ModuleTypeObject)
if !ok {
t.Errorf("invalid object type, expected model.ModuleTypeObject")
return false
}
if mtObj.Value != nil {
t.Errorf("invalid return value")
return false
}

if mtObj.Value != expectedResult {
t.Errorf("invalid return value")
return false
return true
})
if err != nil {
t.Error(err)
}

return true
})
if err != nil {
t.Error(err)
}
}

func TestCorrectModuleTypeEncodeDecode(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions core/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ func (dec *Decoder) readFloat() (float64, error) {
bits := binary.LittleEndian.Uint64(dec.buffer)
return math.Float64frombits(bits), nil
}
func (dec *Decoder) readFloat32() (f float32, err error) {
err = dec.readFull(dec.buffer[:4])
if err != nil {
return 0, err
}
bits := binary.LittleEndian.Uint32(dec.buffer[:4])
return math.Float32frombits(bits), nil
}

func (dec *Decoder) readLZF() ([]byte, error) {
inLen, _, err := dec.readLength()
Expand Down

0 comments on commit f774ec2

Please sign in to comment.