From 47bdac24a3e42a55c26561c4fd2ff2421c55f8a7 Mon Sep 17 00:00:00 2001 From: Achille Date: Mon, 8 Nov 2021 17:23:11 -0800 Subject: [PATCH] thrift (#105) * add thrift package * add encoder and decoder APIs * add a bit of documentation * add enum struct field tag * add a test for the enum struct type * use ioutil.Discard to be compatible with Go 1.14 * add support for union types * don't use IsExported to be compatible with older Go versions * add support for required fields * improve union model * improve error reporting * improve error reporting (2) * add a strict mode to the decoder * add debug reader and writer * use reflect.Value.IsZero * simplify API + fix use of non-zigzag varint encoding * simplify and optimize reading of BINARY type * fixes * test that io.EOF is returned when decoding an empty input * optimize required fields validation * cleanup * fix missing field index * consider anonymous embedded structs * rename error fields to 'cause' * remove magic support for unsigned integer values --- Makefile | 5 +- thrift/binary.go | 367 ++++++++++++++++++++++ thrift/compact.go | 346 +++++++++++++++++++++ thrift/debug.go | 230 ++++++++++++++ thrift/decode.go | 661 ++++++++++++++++++++++++++++++++++++++++ thrift/decode_test.go | 19 ++ thrift/encode.go | 382 +++++++++++++++++++++++ thrift/error.go | 111 +++++++ thrift/protocol.go | 73 +++++ thrift/protocol_test.go | 204 +++++++++++++ thrift/struct.go | 132 ++++++++ thrift/thrift.go | 164 ++++++++++ thrift/thrift_test.go | 298 ++++++++++++++++++ thrift/unsafe.go | 24 ++ 14 files changed, 3015 insertions(+), 1 deletion(-) create mode 100644 thrift/binary.go create mode 100644 thrift/compact.go create mode 100644 thrift/debug.go create mode 100644 thrift/decode.go create mode 100644 thrift/decode_test.go create mode 100644 thrift/encode.go create mode 100644 thrift/error.go create mode 100644 thrift/protocol.go create mode 100644 thrift/protocol_test.go create mode 100644 thrift/struct.go create mode 100644 thrift/thrift.go create mode 100644 thrift/thrift_test.go create mode 100644 thrift/unsafe.go diff --git a/Makefile b/Makefile index 18fb2d8..892c83c 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ go-fuzz-build := ${GOPATH}/bin/go-fuzz-build go-fuzz-corpus := ${GOPATH}/src/github.com/dvyukov/go-fuzz-corpus go-fuzz-dep := ${GOPATH}/src/github.com/dvyukov/go-fuzz/go-fuzz-dep -test: test-ascii test-json test-json-bugs test-json-1.17 test-proto test-iso8601 +test: test-ascii test-json test-json-bugs test-json-1.17 test-proto test-iso8601 test-thrift test-ascii: go test -cover -race ./ascii @@ -30,6 +30,9 @@ test-proto: test-iso8601: go test -cover -race ./iso8601 +test-thrift: + go test -cover -race ./thrift + $(benchstat): GO111MODULE=off go get -u golang.org/x/perf/cmd/benchstat diff --git a/thrift/binary.go b/thrift/binary.go new file mode 100644 index 0000000..18d95d9 --- /dev/null +++ b/thrift/binary.go @@ -0,0 +1,367 @@ +package thrift + +import ( + "bufio" + "bytes" + "encoding/binary" + "fmt" + "io" + "math" +) + +// BinaryProtocol is a Protocol implementation for the binary thrift protocol. +// +// https://github.com/apache/thrift/blob/master/doc/specs/thrift-binary-protocol.md +type BinaryProtocol struct { + NonStrict bool +} + +func (p *BinaryProtocol) NewReader(r io.Reader) Reader { + return &binaryReader{p: p, r: r} +} + +func (p *BinaryProtocol) NewWriter(w io.Writer) Writer { + return &binaryWriter{p: p, w: w} +} + +func (p *BinaryProtocol) Features() Features { + return 0 +} + +type binaryReader struct { + p *BinaryProtocol + r io.Reader + b [8]byte +} + +func (r *binaryReader) Protocol() Protocol { + return r.p +} + +func (r *binaryReader) Reader() io.Reader { + return r.r +} + +func (r *binaryReader) ReadBool() (bool, error) { + v, err := r.ReadByte() + return v != 0, err +} + +func (r *binaryReader) ReadInt8() (int8, error) { + b, err := r.ReadByte() + return int8(b), err +} + +func (r *binaryReader) ReadInt16() (int16, error) { + b, err := r.read(2) + if len(b) < 2 { + return 0, err + } + return int16(binary.BigEndian.Uint16(b)), nil +} + +func (r *binaryReader) ReadInt32() (int32, error) { + b, err := r.read(4) + if len(b) < 4 { + return 0, err + } + return int32(binary.BigEndian.Uint32(b)), nil +} + +func (r *binaryReader) ReadInt64() (int64, error) { + b, err := r.read(8) + if len(b) < 8 { + return 0, err + } + return int64(binary.BigEndian.Uint64(b)), nil +} + +func (r *binaryReader) ReadFloat64() (float64, error) { + b, err := r.read(8) + if len(b) < 8 { + return 0, err + } + return math.Float64frombits(binary.BigEndian.Uint64(b)), nil +} + +func (r *binaryReader) ReadBytes() ([]byte, error) { + n, err := r.ReadLength() + if err != nil { + return nil, err + } + b := make([]byte, n) + _, err = io.ReadFull(r.r, b) + return b, err +} + +func (r *binaryReader) ReadString() (string, error) { + b, err := r.ReadBytes() + return unsafeBytesToString(b), err +} + +func (r *binaryReader) ReadLength() (int, error) { + b, err := r.read(4) + if len(b) < 4 { + return 0, err + } + n := binary.BigEndian.Uint32(b) + if n > math.MaxInt32 { + return 0, fmt.Errorf("length out of range: %d", n) + } + return int(n), nil +} + +func (r *binaryReader) ReadMessage() (Message, error) { + m := Message{} + + b, err := r.read(4) + if len(b) < 4 { + return m, err + } + + if (b[0] >> 7) == 0 { // non-strict + n := int(binary.BigEndian.Uint32(b)) + s := make([]byte, n) + _, err := io.ReadFull(r.r, s) + if err != nil { + return m, dontExpectEOF(err) + } + m.Name = unsafeBytesToString(s) + + t, err := r.ReadInt8() + if err != nil { + return m, dontExpectEOF(err) + } + + m.Type = MessageType(t & 0x7) + } else { + m.Type = MessageType(b[3] & 0x7) + + if m.Name, err = r.ReadString(); err != nil { + return m, dontExpectEOF(err) + } + } + + m.SeqID, err = r.ReadInt32() + return m, err +} + +func (r *binaryReader) ReadField() (Field, error) { + t, err := r.ReadInt8() + if err != nil { + return Field{}, err + } + i, err := r.ReadInt16() + if err != nil { + return Field{}, err + } + return Field{ID: i, Type: Type(t)}, nil +} + +func (r *binaryReader) ReadList() (List, error) { + t, err := r.ReadInt8() + if err != nil { + return List{}, err + } + n, err := r.ReadInt32() + if err != nil { + return List{}, dontExpectEOF(err) + } + return List{Size: n, Type: Type(t)}, nil +} + +func (r *binaryReader) ReadSet() (Set, error) { + l, err := r.ReadList() + return Set(l), err +} + +func (r *binaryReader) ReadMap() (Map, error) { + k, err := r.ReadByte() + if err != nil { + return Map{}, err + } + v, err := r.ReadByte() + if err != nil { + return Map{}, dontExpectEOF(err) + } + n, err := r.ReadInt32() + if err != nil { + return Map{}, dontExpectEOF(err) + } + return Map{Size: n, Key: Type(k), Value: Type(v)}, nil +} + +func (r *binaryReader) ReadByte() (byte, error) { + switch x := r.r.(type) { + case *bytes.Buffer: + return x.ReadByte() + case *bytes.Reader: + return x.ReadByte() + case *bufio.Reader: + return x.ReadByte() + case io.ByteReader: + return x.ReadByte() + default: + b, err := r.read(1) + if err != nil { + return 0, err + } + return b[0], nil + } +} + +func (r *binaryReader) read(n int) ([]byte, error) { + _, err := io.ReadFull(r.r, r.b[:n]) + return r.b[:n], err +} + +type binaryWriter struct { + p *BinaryProtocol + b [8]byte + w io.Writer +} + +func (w *binaryWriter) Protocol() Protocol { + return w.p +} + +func (w *binaryWriter) Writer() io.Writer { + return w.w +} + +func (w *binaryWriter) WriteBool(v bool) error { + var b byte + if v { + b = 1 + } + return w.writeByte(b) +} + +func (w *binaryWriter) WriteInt8(v int8) error { + return w.writeByte(byte(v)) +} + +func (w *binaryWriter) WriteInt16(v int16) error { + binary.BigEndian.PutUint16(w.b[:2], uint16(v)) + return w.write(w.b[:2]) +} + +func (w *binaryWriter) WriteInt32(v int32) error { + binary.BigEndian.PutUint32(w.b[:4], uint32(v)) + return w.write(w.b[:4]) +} + +func (w *binaryWriter) WriteInt64(v int64) error { + binary.BigEndian.PutUint64(w.b[:8], uint64(v)) + return w.write(w.b[:8]) +} + +func (w *binaryWriter) WriteFloat64(v float64) error { + binary.BigEndian.PutUint64(w.b[:8], math.Float64bits(v)) + return w.write(w.b[:8]) +} + +func (w *binaryWriter) WriteBytes(v []byte) error { + if err := w.WriteLength(len(v)); err != nil { + return err + } + return w.write(v) +} + +func (w *binaryWriter) WriteString(v string) error { + if err := w.WriteLength(len(v)); err != nil { + return err + } + return w.writeString(v) +} + +func (w *binaryWriter) WriteLength(n int) error { + if n < 0 { + return fmt.Errorf("negative length cannot be encoded in thrift: %d", n) + } + if n > math.MaxInt32 { + return fmt.Errorf("length is too large to be encoded in thrift: %d", n) + } + return w.WriteInt32(int32(n)) +} + +func (w *binaryWriter) WriteMessage(m Message) error { + if w.p.NonStrict { + if err := w.WriteString(m.Name); err != nil { + return err + } + if err := w.writeByte(byte(m.Type)); err != nil { + return err + } + } else { + w.b[0] = 1 << 7 + w.b[1] = 0 + w.b[2] = 0 + w.b[3] = byte(m.Type) & 0x7 + binary.BigEndian.PutUint32(w.b[4:], uint32(len(m.Name))) + + if err := w.write(w.b[:8]); err != nil { + return err + } + if err := w.writeString(m.Name); err != nil { + return err + } + } + return w.WriteInt32(m.SeqID) +} + +func (w *binaryWriter) WriteField(f Field) error { + if err := w.writeByte(byte(f.Type)); err != nil { + return err + } + return w.WriteInt16(f.ID) +} + +func (w *binaryWriter) WriteList(l List) error { + if err := w.writeByte(byte(l.Type)); err != nil { + return err + } + return w.WriteInt32(l.Size) +} + +func (w *binaryWriter) WriteSet(s Set) error { + return w.WriteList(List(s)) +} + +func (w *binaryWriter) WriteMap(m Map) error { + if err := w.writeByte(byte(m.Key)); err != nil { + return err + } + if err := w.writeByte(byte(m.Value)); err != nil { + return err + } + return w.WriteInt32(m.Size) +} + +func (w *binaryWriter) write(b []byte) error { + _, err := w.w.Write(b) + return err +} + +func (w *binaryWriter) writeString(s string) error { + _, err := io.WriteString(w.w, s) + return err +} + +func (w *binaryWriter) writeByte(b byte) error { + // The special cases are intended to reduce the runtime overheadof testing + // for the io.ByteWriter interface for common types. Type assertions on a + // concrete type is just a pointer comparison, instead of requiring a + // complex lookup in the type metadata. + switch x := w.w.(type) { + case *bytes.Buffer: + return x.WriteByte(b) + case *bufio.Writer: + return x.WriteByte(b) + case io.ByteWriter: + return x.WriteByte(b) + default: + w.b[0] = b + return w.write(w.b[:1]) + } +} diff --git a/thrift/compact.go b/thrift/compact.go new file mode 100644 index 0000000..6a28657 --- /dev/null +++ b/thrift/compact.go @@ -0,0 +1,346 @@ +package thrift + +import ( + "bufio" + "bytes" + "encoding/binary" + "fmt" + "io" + "math" +) + +// CompactProtocol is a Protocol implementation for the compact thrift protocol. +// +// https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#integer-encoding +type CompactProtocol struct{} + +func (p *CompactProtocol) NewReader(r io.Reader) Reader { + return &compactReader{protocol: p, binary: binaryReader{r: r}} +} + +func (p *CompactProtocol) NewWriter(w io.Writer) Writer { + return &compactWriter{protocol: p, binary: binaryWriter{w: w}} +} + +func (p *CompactProtocol) Features() Features { + return UseDeltaEncoding | CoalesceBoolFields +} + +type compactReader struct { + protocol *CompactProtocol + binary binaryReader +} + +func (r *compactReader) Protocol() Protocol { + return r.protocol +} + +func (r *compactReader) Reader() io.Reader { + return r.binary.Reader() +} + +func (r *compactReader) ReadBool() (bool, error) { + return r.binary.ReadBool() +} + +func (r *compactReader) ReadInt8() (int8, error) { + return r.binary.ReadInt8() +} + +func (r *compactReader) ReadInt16() (int16, error) { + v, err := r.readVarint("int16", math.MinInt16, math.MaxInt16) + return int16(v), err +} + +func (r *compactReader) ReadInt32() (int32, error) { + v, err := r.readVarint("int32", math.MinInt32, math.MaxInt32) + return int32(v), err +} + +func (r *compactReader) ReadInt64() (int64, error) { + return r.readVarint("int64", math.MinInt64, math.MaxInt64) +} + +func (r *compactReader) ReadFloat64() (float64, error) { + return r.binary.ReadFloat64() +} + +func (r *compactReader) ReadBytes() ([]byte, error) { + n, err := r.ReadLength() + if err != nil { + return nil, err + } + b := make([]byte, n) + _, err = io.ReadFull(r.Reader(), b) + return b, err +} + +func (r *compactReader) ReadString() (string, error) { + b, err := r.ReadBytes() + return unsafeBytesToString(b), err +} + +func (r *compactReader) ReadLength() (int, error) { + n, err := r.readUvarint("length", math.MaxInt32) + return int(n), err +} + +func (r *compactReader) ReadMessage() (Message, error) { + m := Message{} + + b0, err := r.ReadByte() + if err != nil { + return m, err + } + if b0 != 0x82 { + return m, fmt.Errorf("invalid protocol id found when reading thrift message: %#x", b0) + } + + b1, err := r.ReadByte() + if err != nil { + return m, dontExpectEOF(err) + } + + seqID, err := r.readUvarint("seq id", math.MaxInt32) + if err != nil { + return m, dontExpectEOF(err) + } + + m.Type = MessageType(b1) & 0x7 + m.SeqID = int32(seqID) + m.Name, err = r.ReadString() + return m, dontExpectEOF(err) +} + +func (r *compactReader) ReadField() (Field, error) { + f := Field{} + + b, err := r.ReadByte() + if err != nil { + return f, err + } + + if Type(b) == STOP { + return f, nil + } + + if (b >> 4) != 0 { + f = Field{ID: int16(b >> 4), Type: Type(b & 0xF), Delta: true} + } else { + i, err := r.ReadInt16() + if err != nil { + return f, dontExpectEOF(err) + } + f = Field{ID: i, Type: Type(b)} + } + + return f, nil +} + +func (r *compactReader) ReadList() (List, error) { + b, err := r.ReadByte() + if err != nil { + return List{}, err + } + if (b >> 4) != 0xF { + return List{Size: int32(b >> 4), Type: Type(b & 0xF)}, nil + } + n, err := r.readUvarint("list size", math.MaxInt32) + if err != nil { + return List{}, dontExpectEOF(err) + } + return List{Size: int32(n), Type: Type(b & 0xF)}, nil +} + +func (r *compactReader) ReadSet() (Set, error) { + l, err := r.ReadList() + return Set(l), err +} + +func (r *compactReader) ReadMap() (Map, error) { + n, err := r.readUvarint("map size", math.MaxInt32) + if err != nil { + return Map{}, err + } + if n == 0 { // empty map + return Map{}, nil + } + b, err := r.ReadByte() + if err != nil { + return Map{}, dontExpectEOF(err) + } + return Map{Size: int32(n), Key: Type(b >> 4), Value: Type(b & 0xF)}, nil +} + +func (r *compactReader) ReadByte() (byte, error) { + return r.binary.ReadByte() +} + +func (r *compactReader) readUvarint(typ string, max uint64) (uint64, error) { + var br io.ByteReader + + switch x := r.Reader().(type) { + case *bytes.Buffer: + br = x + case *bytes.Reader: + br = x + case *bufio.Reader: + br = x + case io.ByteReader: + br = x + default: + br = &r.binary + } + + u, err := binary.ReadUvarint(br) + if err == nil { + if u > max { + err = fmt.Errorf("%s varint out of range: %d > %d", typ, u, max) + } + } + return u, err +} + +func (r *compactReader) readVarint(typ string, min, max int64) (int64, error) { + var br io.ByteReader + + switch x := r.Reader().(type) { + case *bytes.Buffer: + br = x + case *bytes.Reader: + br = x + case *bufio.Reader: + br = x + case io.ByteReader: + br = x + default: + br = &r.binary + } + + v, err := binary.ReadVarint(br) + if err == nil { + if v < min || v > max { + err = fmt.Errorf("%s varint out of range: %d not in [%d;%d]", typ, v, min, max) + } + } + return v, err +} + +type compactWriter struct { + protocol *CompactProtocol + binary binaryWriter + varint [binary.MaxVarintLen64]byte +} + +func (w *compactWriter) Protocol() Protocol { + return w.protocol +} + +func (w *compactWriter) Writer() io.Writer { + return w.binary.Writer() +} + +func (w *compactWriter) WriteBool(v bool) error { + return w.binary.WriteBool(v) +} + +func (w *compactWriter) WriteInt8(v int8) error { + return w.binary.WriteInt8(v) +} + +func (w *compactWriter) WriteInt16(v int16) error { + return w.writeVarint(int64(v)) +} + +func (w *compactWriter) WriteInt32(v int32) error { + return w.writeVarint(int64(v)) +} + +func (w *compactWriter) WriteInt64(v int64) error { + return w.writeVarint(v) +} + +func (w *compactWriter) WriteFloat64(v float64) error { + return w.binary.WriteFloat64(v) +} + +func (w *compactWriter) WriteBytes(v []byte) error { + if err := w.WriteLength(len(v)); err != nil { + return err + } + return w.binary.write(v) +} + +func (w *compactWriter) WriteString(v string) error { + if err := w.WriteLength(len(v)); err != nil { + return err + } + return w.binary.writeString(v) +} + +func (w *compactWriter) WriteLength(n int) error { + if n < 0 { + return fmt.Errorf("negative length cannot be encoded in thrift: %d", n) + } + if n > math.MaxInt32 { + return fmt.Errorf("length is too large to be encoded in thrift: %d", n) + } + return w.writeUvarint(uint64(n)) +} + +func (w *compactWriter) WriteMessage(m Message) error { + if err := w.binary.writeByte(0x82); err != nil { + return err + } + if err := w.binary.writeByte(byte(m.Type)); err != nil { + return err + } + if err := w.writeUvarint(uint64(m.SeqID)); err != nil { + return err + } + return w.WriteString(m.Name) +} + +func (w *compactWriter) WriteField(f Field) error { + if f.Type == STOP { + return w.binary.writeByte(0) + } + if f.ID <= 15 { + return w.binary.writeByte(byte(f.ID<<4) | byte(f.Type)) + } + if err := w.binary.writeByte(byte(f.Type)); err != nil { + return err + } + return w.WriteInt16(f.ID) +} + +func (w *compactWriter) WriteList(l List) error { + if l.Size <= 14 { + return w.binary.writeByte(byte(l.Size<<4) | byte(l.Type)) + } + if err := w.binary.writeByte(0xF0 | byte(l.Type)); err != nil { + return err + } + return w.writeUvarint(uint64(l.Size)) +} + +func (w *compactWriter) WriteSet(s Set) error { + return w.WriteList(List(s)) +} + +func (w *compactWriter) WriteMap(m Map) error { + if err := w.writeUvarint(uint64(m.Size)); err != nil || m.Size == 0 { + return err + } + return w.binary.writeByte((byte(m.Key) << 4) | byte(m.Value)) +} + +func (w *compactWriter) writeUvarint(v uint64) error { + n := binary.PutUvarint(w.varint[:], v) + return w.binary.write(w.varint[:n]) +} + +func (w *compactWriter) writeVarint(v int64) error { + n := binary.PutVarint(w.varint[:], v) + return w.binary.write(w.varint[:n]) +} diff --git a/thrift/debug.go b/thrift/debug.go new file mode 100644 index 0000000..3bf76d7 --- /dev/null +++ b/thrift/debug.go @@ -0,0 +1,230 @@ +package thrift + +import ( + "io" + "log" +) + +func NewDebugReader(r Reader, l *log.Logger) Reader { + return &debugReader{ + r: r, + l: l, + } +} + +func NewDebugWriter(w Writer, l *log.Logger) Writer { + return &debugWriter{ + w: w, + l: l, + } +} + +type debugReader struct { + r Reader + l *log.Logger +} + +func (d *debugReader) log(method string, res interface{}, err error) { + if err != nil { + d.l.Printf("(%T).%s() → ERROR: %v", d.r, method, err) + } else { + d.l.Printf("(%T).%s() → %#v", d.r, method, res) + } +} + +func (d *debugReader) Protocol() Protocol { + return d.r.Protocol() +} + +func (d *debugReader) Reader() io.Reader { + return d.r.Reader() +} + +func (d *debugReader) ReadBool() (bool, error) { + v, err := d.r.ReadBool() + d.log("ReadBool", v, err) + return v, err +} + +func (d *debugReader) ReadInt8() (int8, error) { + v, err := d.r.ReadInt8() + d.log("ReadInt8", v, err) + return v, err +} + +func (d *debugReader) ReadInt16() (int16, error) { + v, err := d.r.ReadInt16() + d.log("ReadInt16", v, err) + return v, err +} + +func (d *debugReader) ReadInt32() (int32, error) { + v, err := d.r.ReadInt32() + d.log("ReadInt32", v, err) + return v, err +} + +func (d *debugReader) ReadInt64() (int64, error) { + v, err := d.r.ReadInt64() + d.log("ReadInt64", v, err) + return v, err +} + +func (d *debugReader) ReadFloat64() (float64, error) { + v, err := d.r.ReadFloat64() + d.log("ReadFloat64", v, err) + return v, err +} + +func (d *debugReader) ReadBytes() ([]byte, error) { + v, err := d.r.ReadBytes() + d.log("ReadBytes", v, err) + return v, err +} + +func (d *debugReader) ReadString() (string, error) { + v, err := d.r.ReadString() + d.log("ReadString", v, err) + return v, err +} + +func (d *debugReader) ReadLength() (int, error) { + v, err := d.r.ReadLength() + d.log("ReadLength", v, err) + return v, err +} + +func (d *debugReader) ReadMessage() (Message, error) { + v, err := d.r.ReadMessage() + d.log("ReadMessage", v, err) + return v, err +} + +func (d *debugReader) ReadField() (Field, error) { + v, err := d.r.ReadField() + d.log("ReadField", v, err) + return v, err +} + +func (d *debugReader) ReadList() (List, error) { + v, err := d.r.ReadList() + d.log("ReadList", v, err) + return v, err +} + +func (d *debugReader) ReadSet() (Set, error) { + v, err := d.r.ReadSet() + d.log("ReadSet", v, err) + return v, err +} + +func (d *debugReader) ReadMap() (Map, error) { + v, err := d.r.ReadMap() + d.log("ReadMap", v, err) + return v, err +} + +type debugWriter struct { + w Writer + l *log.Logger +} + +func (d *debugWriter) log(method string, arg interface{}, err error) { + if err != nil { + d.l.Printf("(%T).%s(%#v) → ERROR: %v", d.w, method, arg, err) + } else { + d.l.Printf("(%T).%s(%#v)", d.w, method, arg) + } +} + +func (d *debugWriter) Protocol() Protocol { + return d.w.Protocol() +} + +func (d *debugWriter) Writer() io.Writer { + return d.w.Writer() +} + +func (d *debugWriter) WriteBool(v bool) error { + err := d.w.WriteBool(v) + d.log("WriteBool", v, err) + return err +} + +func (d *debugWriter) WriteInt8(v int8) error { + err := d.w.WriteInt8(v) + d.log("WriteInt8", v, err) + return err +} + +func (d *debugWriter) WriteInt16(v int16) error { + err := d.w.WriteInt16(v) + d.log("WriteInt16", v, err) + return err +} + +func (d *debugWriter) WriteInt32(v int32) error { + err := d.w.WriteInt32(v) + d.log("WriteInt32", v, err) + return err +} + +func (d *debugWriter) WriteInt64(v int64) error { + err := d.w.WriteInt64(v) + d.log("WriteInt64", v, err) + return err +} + +func (d *debugWriter) WriteFloat64(v float64) error { + err := d.w.WriteFloat64(v) + d.log("WriteFloat64", v, err) + return err +} + +func (d *debugWriter) WriteBytes(v []byte) error { + err := d.w.WriteBytes(v) + d.log("WriteBytes", v, err) + return err +} + +func (d *debugWriter) WriteString(v string) error { + err := d.w.WriteString(v) + d.log("WriteString", v, err) + return err +} + +func (d *debugWriter) WriteLength(n int) error { + err := d.w.WriteLength(n) + d.log("WriteLength", n, err) + return err +} + +func (d *debugWriter) WriteMessage(m Message) error { + err := d.w.WriteMessage(m) + d.log("WriteMessage", m, err) + return err +} + +func (d *debugWriter) WriteField(f Field) error { + err := d.w.WriteField(f) + d.log("WriteField", f, err) + return err +} + +func (d *debugWriter) WriteList(l List) error { + err := d.w.WriteList(l) + d.log("WriteList", l, err) + return err +} + +func (d *debugWriter) WriteSet(s Set) error { + err := d.w.WriteSet(s) + d.log("WriteSet", s, err) + return err +} + +func (d *debugWriter) WriteMap(m Map) error { + err := d.w.WriteMap(m) + d.log("WriteMap", m, err) + return err +} diff --git a/thrift/decode.go b/thrift/decode.go new file mode 100644 index 0000000..14f6528 --- /dev/null +++ b/thrift/decode.go @@ -0,0 +1,661 @@ +package thrift + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "math/bits" + "reflect" + "sync/atomic" +) + +// Unmarshal deserializes the thrift data from b to v using to the protocol p. +// +// The function errors if the data in b does not match the type of v. +// +// The function panics if v cannot be converted to a thrift representation. +func Unmarshal(p Protocol, b []byte, v interface{}) error { + br := bytes.NewReader(b) + pr := p.NewReader(br) + + if err := NewDecoder(pr).Decode(v); err != nil { + return err + } + + if n := br.Len(); n != 0 { + return fmt.Errorf("unexpected trailing bytes at the end of thrift input: %d", n) + } + + return nil +} + +type Decoder struct { + r Reader + f flags +} + +func NewDecoder(r Reader) *Decoder { + return &Decoder{r: r, f: decoderFlags(r)} +} + +func (d *Decoder) Decode(v interface{}) error { + t := reflect.TypeOf(v) + p := reflect.ValueOf(v) + + if t.Kind() != reflect.Ptr { + panic("thrift.(*Decoder).Decode: expected pointer type but got " + t.String()) + } + + t = t.Elem() + p = p.Elem() + + cache, _ := decoderCache.Load().(map[typeID]decodeFunc) + decode, _ := cache[makeTypeID(t)] + + if decode == nil { + decode = decodeFuncOf(t, make(decodeFuncCache)) + + newCache := make(map[typeID]decodeFunc, len(cache)+1) + newCache[makeTypeID(t)] = decode + for k, v := range cache { + newCache[k] = v + } + + decoderCache.Store(newCache) + } + + return decode(d.r, p, d.f) +} + +func (d *Decoder) Reset(r Reader) { + d.r = r + d.f = d.f.without(protocolFlags).with(decoderFlags(r)) +} + +func (d *Decoder) SetStrict(enabled bool) { + if enabled { + d.f = d.f.with(strict) + } else { + d.f = d.f.without(strict) + } +} + +func decoderFlags(r Reader) flags { + return flags(r.Protocol().Features() << featuresBitOffset) +} + +var decoderCache atomic.Value // map[typeID]decodeFunc + +type decodeFunc func(Reader, reflect.Value, flags) error + +type decodeFuncCache map[reflect.Type]decodeFunc + +func decodeFuncOf(t reflect.Type, seen decodeFuncCache) decodeFunc { + f := seen[t] + if f != nil { + return f + } + switch t.Kind() { + case reflect.Bool: + f = decodeBool + case reflect.Int8: + f = decodeInt8 + case reflect.Int16: + f = decodeInt16 + case reflect.Int32: + f = decodeInt32 + case reflect.Int64, reflect.Int: + f = decodeInt64 + case reflect.Float32, reflect.Float64: + f = decodeFloat64 + case reflect.String: + f = decodeString + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { // []byte + f = decodeBytes + } else { + f = decodeFuncSliceOf(t, seen) + } + case reflect.Map: + f = decodeFuncMapOf(t, seen) + case reflect.Struct: + f = decodeFuncStructOf(t, seen) + case reflect.Ptr: + f = decodeFuncPtrOf(t, seen) + default: + panic("type cannot be decoded in thrift: " + t.String()) + } + seen[t] = f + return f +} + +func decodeBool(r Reader, v reflect.Value, _ flags) error { + b, err := r.ReadBool() + if err != nil { + return err + } + v.SetBool(b) + return nil +} + +func decodeInt8(r Reader, v reflect.Value, _ flags) error { + i, err := r.ReadInt8() + if err != nil { + return err + } + v.SetInt(int64(i)) + return nil +} + +func decodeInt16(r Reader, v reflect.Value, _ flags) error { + i, err := r.ReadInt16() + if err != nil { + return err + } + v.SetInt(int64(i)) + return nil +} + +func decodeInt32(r Reader, v reflect.Value, _ flags) error { + i, err := r.ReadInt32() + if err != nil { + return err + } + v.SetInt(int64(i)) + return nil +} + +func decodeInt64(r Reader, v reflect.Value, _ flags) error { + i, err := r.ReadInt64() + if err != nil { + return err + } + v.SetInt(int64(i)) + return nil +} + +func decodeFloat64(r Reader, v reflect.Value, _ flags) error { + f, err := r.ReadFloat64() + if err != nil { + return err + } + v.SetFloat(f) + return nil +} + +func decodeString(r Reader, v reflect.Value, _ flags) error { + s, err := r.ReadString() + if err != nil { + return err + } + v.SetString(s) + return nil +} + +func decodeBytes(r Reader, v reflect.Value, _ flags) error { + b, err := r.ReadBytes() + if err != nil { + return err + } + v.SetBytes(b) + return nil +} + +func decodeFuncSliceOf(t reflect.Type, seen decodeFuncCache) decodeFunc { + elem := t.Elem() + typ := TypeOf(elem) + dec := decodeFuncOf(elem, seen) + + return func(r Reader, v reflect.Value, flags flags) error { + l, err := r.ReadList() + if err != nil { + return err + } + + // TODO: implement type conversions? + if typ != l.Type { + if flags.have(strict) { + return &TypeMismatch{item: "list item", Expect: typ, Found: l.Type} + } + return nil + } + + v.Set(reflect.MakeSlice(t, int(l.Size), int(l.Size))) + flags = flags.only(decodeFlags) + + for i := 0; i < int(l.Size); i++ { + if err := dec(r, v.Index(i), flags); err != nil { + return with(dontExpectEOF(err), &decodeErrorList{cause: l, index: i}) + } + } + + return nil + } +} + +func decodeFuncMapOf(t reflect.Type, seen decodeFuncCache) decodeFunc { + key, elem := t.Key(), t.Elem() + if elem.Size() == 0 { // map[?]struct{} + return decodeFuncMapAsSetOf(t, seen) + } + + mapType := reflect.MapOf(key, elem) + keyZero := reflect.Zero(key) + elemZero := reflect.Zero(elem) + keyType := TypeOf(key) + elemType := TypeOf(elem) + decodeKey := decodeFuncOf(key, seen) + decodeElem := decodeFuncOf(elem, seen) + + return func(r Reader, v reflect.Value, flags flags) error { + m, err := r.ReadMap() + if err != nil { + return err + } + + v.Set(reflect.MakeMapWithSize(mapType, int(m.Size))) + + if m.Size == 0 { // empty map + return nil + } + + // TODO: implement type conversions? + if keyType != m.Key { + if flags.have(strict) { + return &TypeMismatch{item: "map key", Expect: keyType, Found: m.Key} + } + return nil + } + + if elemType != m.Value { + if flags.have(strict) { + return &TypeMismatch{item: "map value", Expect: elemType, Found: m.Value} + } + return nil + } + + tmpKey := reflect.New(key).Elem() + tmpElem := reflect.New(elem).Elem() + flags = flags.only(decodeFlags) + + for i := 0; i < int(m.Size); i++ { + if err := decodeKey(r, tmpKey, flags); err != nil { + return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i}) + } + if err := decodeElem(r, tmpElem, flags); err != nil { + return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i}) + } + v.SetMapIndex(tmpKey, tmpElem) + tmpKey.Set(keyZero) + tmpElem.Set(elemZero) + } + + return nil + } +} + +func decodeFuncMapAsSetOf(t reflect.Type, seen decodeFuncCache) decodeFunc { + key, elem := t.Key(), t.Elem() + keyZero := reflect.Zero(key) + elemZero := reflect.Zero(elem) + typ := TypeOf(key) + dec := decodeFuncOf(key, seen) + + return func(r Reader, v reflect.Value, flags flags) error { + s, err := r.ReadSet() + if err != nil { + return err + } + + v.Set(reflect.MakeMapWithSize(t, int(s.Size))) + + if s.Size == 0 { + return nil + } + + // TODO: implement type conversions? + if typ != s.Type { + if flags.have(strict) { + return &TypeMismatch{item: "list item", Expect: typ, Found: s.Type} + } + return nil + } + + tmp := reflect.New(key).Elem() + flags = flags.only(decodeFlags) + + for i := 0; i < int(s.Size); i++ { + if err := dec(r, tmp, flags); err != nil { + return with(dontExpectEOF(err), &decodeErrorSet{cause: s, index: i}) + } + v.SetMapIndex(tmp, elemZero) + tmp.Set(keyZero) + } + + return nil + } +} + +type structDecoder struct { + fields []structDecoderField + union []int + minID int16 + zero reflect.Value + required []uint64 +} + +func (dec *structDecoder) decode(r Reader, v reflect.Value, flags flags) error { + v.Set(dec.zero) + flags = flags.only(decodeFlags) + coalesceBoolFields := flags.have(coalesceBoolFields) + + lastField := reflect.Value{} + union := len(dec.union) > 0 + seen := make([]uint64, 1) + if len(dec.required) > len(seen) { + seen = make([]uint64, len(dec.required)) + } + + err := readStruct(r, func(r Reader, f Field) error { + i := int(f.ID) - int(dec.minID) + if i < 0 || i >= len(dec.fields) || dec.fields[i].decode == nil { + return skipField(r, f) + } + field := &dec.fields[i] + seen[i/64] |= 1 << (i % 64) + + // TODO: implement type conversions? + if f.Type != field.typ && !(f.Type == TRUE && field.typ == BOOL) { + if flags.have(strict) { + return &TypeMismatch{item: "field value", Expect: field.typ, Found: f.Type} + } + return nil + } + + x := v + for _, i := range field.index { + if x.Kind() == reflect.Ptr { + x = x.Elem() + } + if x = x.Field(i); x.Kind() == reflect.Ptr { + if x.IsNil() { + x.Set(reflect.New(x.Type().Elem())) + } + } + } + + if union { + v.Set(dec.zero) + } + + lastField = x + + if coalesceBoolFields && (f.Type == TRUE || f.Type == FALSE) { + x.SetBool(f.Type == TRUE) + return nil + } + + return field.decode(r, x, flags.with(field.flags)) + }) + if err != nil { + return err + } + + for i, required := range dec.required { + if mask := required & seen[i]; mask != required { + index := bits.TrailingZeros64(mask) + field := &dec.fields[i+index] + return &MissingField{Field: Field{ID: field.id, Type: field.typ}} + } + } + + if union && lastField.IsValid() { + v.FieldByIndex(dec.union).Set(lastField.Addr()) + } + + return nil +} + +type structDecoderField struct { + index []int + id int16 + flags flags + typ Type + decode decodeFunc +} + +func decodeFuncStructOf(t reflect.Type, seen decodeFuncCache) decodeFunc { + dec := &structDecoder{ + zero: reflect.Zero(t), + } + decode := dec.decode + seen[t] = decode + + fields := make([]structDecoderField, 0, t.NumField()) + forEachStructField(t, nil, func(f structField) { + if f.flags.have(union) { + dec.union = f.index + } else { + fields = append(fields, structDecoderField{ + index: f.index, + id: f.id, + flags: f.flags, + typ: TypeOf(f.typ), + decode: decodeFuncStructFieldOf(f, seen), + }) + } + }) + + minID := int16(0) + maxID := int16(0) + + for _, f := range fields { + if f.id < minID || minID == 0 { + minID = f.id + } + if f.id > maxID { + maxID = f.id + } + } + + dec.fields = make([]structDecoderField, (maxID-minID)+1) + dec.minID = minID + dec.required = make([]uint64, len(fields)/64+1) + + for _, f := range fields { + i := f.id - minID + p := dec.fields[i] + if p.decode != nil { + panic(fmt.Errorf("thrift struct field id %d is present multiple times in %s with types %s and %s", f.id, t, p.typ, f.typ)) + } + dec.fields[i] = f + if f.flags.have(required) { + dec.required[i/64] |= 1 << (i % 64) + } + } + + return decode +} + +func decodeFuncStructFieldOf(f structField, seen decodeFuncCache) decodeFunc { + if f.flags.have(enum) { + switch f.typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return decodeInt32 + } + } + return decodeFuncOf(f.typ, seen) +} + +func decodeFuncPtrOf(t reflect.Type, seen decodeFuncCache) decodeFunc { + elem := t.Elem() + decode := decodeFuncOf(t.Elem(), seen) + return func(r Reader, v reflect.Value, f flags) error { + if v.IsNil() { + v.Set(reflect.New(elem)) + } + return decode(r, v.Elem(), f) + } +} + +func readBinary(r Reader, f func(io.Reader) error) error { + n, err := r.ReadLength() + if err != nil { + return err + } + return dontExpectEOF(f(io.LimitReader(r.Reader(), int64(n)))) +} + +func readList(r Reader, f func(Reader, Type) error) error { + l, err := r.ReadList() + if err != nil { + return err + } + + for i := 0; i < int(l.Size); i++ { + if err := f(r, l.Type); err != nil { + return with(dontExpectEOF(err), &decodeErrorList{cause: l, index: i}) + } + } + + return nil +} + +func readSet(r Reader, f func(Reader, Type) error) error { + s, err := r.ReadSet() + if err != nil { + return err + } + + for i := 0; i < int(s.Size); i++ { + if err := f(r, s.Type); err != nil { + return with(dontExpectEOF(err), &decodeErrorSet{cause: s, index: i}) + } + } + + return nil +} + +func readMap(r Reader, f func(Reader, Type, Type) error) error { + m, err := r.ReadMap() + if err != nil { + return err + } + + for i := 0; i < int(m.Size); i++ { + if err := f(r, m.Key, m.Value); err != nil { + return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i}) + } + } + + return nil +} + +func readStruct(r Reader, f func(Reader, Field) error) error { + lastFieldID := int16(0) + numFields := 0 + + for { + x, err := r.ReadField() + if err != nil { + if numFields > 0 { + err = dontExpectEOF(err) + } + return err + } + + if x.Type == STOP { + return nil + } + + if x.Delta { + x.ID += lastFieldID + x.Delta = false + } + + if err := f(r, x); err != nil { + return with(dontExpectEOF(err), &decodeErrorField{cause: x}) + } + + lastFieldID = x.ID + numFields++ + } +} + +func skip(r Reader, t Type) error { + var err error + switch t { + case TRUE, FALSE: + _, err = r.ReadBool() + case I8: + _, err = r.ReadInt8() + case I16: + _, err = r.ReadInt16() + case I32: + _, err = r.ReadInt32() + case I64: + _, err = r.ReadInt64() + case DOUBLE: + _, err = r.ReadFloat64() + case BINARY: + err = skipBinary(r) + case LIST: + err = skipList(r) + case SET: + err = skipSet(r) + case MAP: + err = skipMap(r) + case STRUCT: + err = skipStruct(r) + default: + return fmt.Errorf("skipping unsupported thrift type %d", t) + } + return err +} + +func skipBinary(r Reader) error { + n, err := r.ReadLength() + if err != nil { + return err + } + if n == 0 { + return nil + } + switch x := r.Reader().(type) { + case *bufio.Reader: + _, err = x.Discard(int(n)) + default: + _, err = io.CopyN(ioutil.Discard, x, int64(n)) + } + return dontExpectEOF(err) +} + +func skipList(r Reader) error { + return readList(r, skip) +} + +func skipSet(r Reader) error { + return readSet(r, skip) +} + +func skipMap(r Reader) error { + return readMap(r, func(r Reader, k, v Type) error { + if err := skip(r, k); err != nil { + return dontExpectEOF(err) + } + if err := skip(r, v); err != nil { + return dontExpectEOF(err) + } + return nil + }) +} + +func skipStruct(r Reader) error { + return readStruct(r, skipField) +} + +func skipField(r Reader, f Field) error { + return skip(r, f.Type) +} diff --git a/thrift/decode_test.go b/thrift/decode_test.go new file mode 100644 index 0000000..7658c14 --- /dev/null +++ b/thrift/decode_test.go @@ -0,0 +1,19 @@ +package thrift_test + +import ( + "bytes" + "io" + "testing" + + "github.com/segmentio/encoding/thrift" +) + +func TestDecodeEOF(t *testing.T) { + p := thrift.CompactProtocol{} + d := thrift.NewDecoder(p.NewReader(bytes.NewReader(nil))) + v := struct{ Name string }{} + + if err := d.Decode(&v); err != io.EOF { + t.Errorf("unexpected error returned: %v", err) + } +} diff --git a/thrift/encode.go b/thrift/encode.go new file mode 100644 index 0000000..6faa334 --- /dev/null +++ b/thrift/encode.go @@ -0,0 +1,382 @@ +package thrift + +import ( + "bytes" + "fmt" + "math" + "reflect" + "sort" + "sync/atomic" +) + +// Marshal serializes v into a thrift representation according to the the +// protocol p. +// +// The function panics if v cannot be converted to a thrift representation. +func Marshal(p Protocol, v interface{}) ([]byte, error) { + buf := new(bytes.Buffer) + enc := NewEncoder(p.NewWriter(buf)) + err := enc.Encode(v) + return buf.Bytes(), err +} + +type Encoder struct { + w Writer + f flags +} + +func NewEncoder(w Writer) *Encoder { + return &Encoder{w: w, f: encoderFlags(w)} +} + +func (e *Encoder) Encode(v interface{}) error { + t := reflect.TypeOf(v) + cache, _ := encoderCache.Load().(map[typeID]encodeFunc) + encode, _ := cache[makeTypeID(t)] + + if encode == nil { + encode = encodeFuncOf(t, make(encodeFuncCache)) + + newCache := make(map[typeID]encodeFunc, len(cache)+1) + newCache[makeTypeID(t)] = encode + for k, v := range cache { + newCache[k] = v + } + + encoderCache.Store(newCache) + } + + return encode(e.w, reflect.ValueOf(v), e.f) +} + +func (e *Encoder) Reset(w Writer) { + e.w = w + e.f = e.f.without(protocolFlags).with(encoderFlags(w)) +} + +func encoderFlags(w Writer) flags { + return flags(w.Protocol().Features() << featuresBitOffset) +} + +var encoderCache atomic.Value // map[typeID]encodeFunc + +type encodeFunc func(Writer, reflect.Value, flags) error + +type encodeFuncCache map[reflect.Type]encodeFunc + +func encodeFuncOf(t reflect.Type, seen encodeFuncCache) encodeFunc { + f := seen[t] + if f != nil { + return f + } + switch t.Kind() { + case reflect.Bool: + f = encodeBool + case reflect.Int8: + f = encodeInt8 + case reflect.Int16: + f = encodeInt16 + case reflect.Int32: + f = encodeInt32 + case reflect.Int64, reflect.Int: + f = encodeInt64 + case reflect.Float32, reflect.Float64: + f = encodeFloat64 + case reflect.String: + f = encodeString + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { + f = encodeBytes + } else { + f = encodeFuncSliceOf(t, seen) + } + case reflect.Map: + f = encodeFuncMapOf(t, seen) + case reflect.Struct: + f = encodeFuncStructOf(t, seen) + case reflect.Ptr: + f = encodeFuncPtrOf(t, seen) + default: + panic("type cannot be encoded in thrift: " + t.String()) + } + seen[t] = f + return f +} + +func encodeBool(w Writer, v reflect.Value, _ flags) error { + return w.WriteBool(v.Bool()) +} + +func encodeInt8(w Writer, v reflect.Value, _ flags) error { + return w.WriteInt8(int8(v.Int())) +} + +func encodeInt16(w Writer, v reflect.Value, _ flags) error { + return w.WriteInt16(int16(v.Int())) +} + +func encodeInt32(w Writer, v reflect.Value, _ flags) error { + return w.WriteInt32(int32(v.Int())) +} + +func encodeInt64(w Writer, v reflect.Value, _ flags) error { + return w.WriteInt64(v.Int()) +} + +func encodeFloat64(w Writer, v reflect.Value, _ flags) error { + return w.WriteFloat64(v.Float()) +} + +func encodeString(w Writer, v reflect.Value, _ flags) error { + return w.WriteString(v.String()) +} + +func encodeBytes(w Writer, v reflect.Value, _ flags) error { + return w.WriteBytes(v.Bytes()) +} + +func encodeFuncSliceOf(t reflect.Type, seen encodeFuncCache) encodeFunc { + elem := t.Elem() + typ := TypeOf(elem) + enc := encodeFuncOf(elem, seen) + + return func(w Writer, v reflect.Value, flags flags) error { + n := v.Len() + if n > math.MaxInt32 { + return fmt.Errorf("slice length is too large to be represented in thrift: %d > max(int32)", n) + } + + err := w.WriteList(List{ + Size: int32(n), + Type: typ, + }) + if err != nil { + return err + } + + for i := 0; i < n; i++ { + if err := enc(w, v.Index(i), flags); err != nil { + return err + } + } + + return nil + } +} + +func encodeFuncMapOf(t reflect.Type, seen encodeFuncCache) encodeFunc { + key, elem := t.Key(), t.Elem() + if elem.Size() == 0 { // map[?]struct{} + return encodeFuncMapAsSetOf(t, seen) + } + + keyType := TypeOf(key) + elemType := TypeOf(elem) + encodeKey := encodeFuncOf(key, seen) + encodeElem := encodeFuncOf(elem, seen) + + return func(w Writer, v reflect.Value, flags flags) error { + n := v.Len() + if n > math.MaxInt32 { + return fmt.Errorf("map length is too large to be represented in thrift: %d > max(int32)", n) + } + + err := w.WriteMap(Map{ + Size: int32(n), + Key: keyType, + Value: elemType, + }) + if err != nil { + return err + } + if n == 0 { // empty map + return nil + } + + for i, iter := 0, v.MapRange(); iter.Next(); i++ { + if err := encodeKey(w, iter.Key(), flags); err != nil { + return err + } + if err := encodeElem(w, iter.Value(), flags); err != nil { + return err + } + } + + return nil + } +} + +func encodeFuncMapAsSetOf(t reflect.Type, seen encodeFuncCache) encodeFunc { + key := t.Key() + typ := TypeOf(key) + enc := encodeFuncOf(key, seen) + + return func(w Writer, v reflect.Value, flags flags) error { + n := v.Len() + if n > math.MaxInt32 { + return fmt.Errorf("map length is too large to be represented in thrift: %d > max(int32)", n) + } + + err := w.WriteSet(Set{ + Size: int32(n), + Type: typ, + }) + if err != nil { + return err + } + if n == 0 { // empty map + return nil + } + + for i, iter := 0, v.MapRange(); iter.Next(); i++ { + if err := enc(w, iter.Key(), flags); err != nil { + return err + } + } + + return nil + } +} + +type structEncoder struct { + fields []structEncoderField + union bool +} + +func (enc *structEncoder) encode(w Writer, v reflect.Value, flags flags) error { + useDeltaEncoding := flags.have(useDeltaEncoding) + coalesceBoolFields := flags.have(coalesceBoolFields) + numFields := int16(0) + lastFieldID := int16(0) + +encodeFields: + for _, f := range enc.fields { + x := v + for _, i := range f.index { + if x.Kind() == reflect.Ptr { + x = x.Elem() + } + if x = x.Field(i); x.Kind() == reflect.Ptr { + if x.IsNil() { + continue encodeFields + } + } + } + + if !f.flags.have(required) && x.IsZero() { + continue encodeFields + } + + field := Field{ + ID: f.id, + Type: f.typ, + } + + if useDeltaEncoding { + if delta := field.ID - lastFieldID; delta <= 15 { + field.ID = delta + field.Delta = true + } + } + + skipValue := coalesceBoolFields && field.Type == BOOL + if skipValue && x.Bool() == true { + field.Type = TRUE + } + + if err := w.WriteField(field); err != nil { + return err + } + + if !skipValue { + if err := f.encode(w, x, flags); err != nil { + return err + } + } + + numFields++ + lastFieldID = f.id + } + + if err := w.WriteField(Field{Type: STOP}); err != nil { + return err + } + + if numFields > 1 && enc.union { + return fmt.Errorf("thrift union had more than one field with a non-zero value (%d)", numFields) + } + + return nil +} + +func (enc *structEncoder) String() string { + if enc.union { + return "union" + } + return "struct" +} + +type structEncoderField struct { + index []int + id int16 + flags flags + typ Type + encode encodeFunc +} + +func encodeFuncStructOf(t reflect.Type, seen encodeFuncCache) encodeFunc { + enc := &structEncoder{ + fields: make([]structEncoderField, 0, t.NumField()), + } + encode := enc.encode + seen[t] = encode + + forEachStructField(t, nil, func(f structField) { + if f.flags.have(union) { + enc.union = true + } else { + enc.fields = append(enc.fields, structEncoderField{ + index: f.index, + id: f.id, + flags: f.flags, + typ: TypeOf(f.typ), + encode: encodeFuncStructFieldOf(f, seen), + }) + } + }) + + sort.SliceStable(enc.fields, func(i, j int) bool { + return enc.fields[i].id < enc.fields[j].id + }) + + for i := len(enc.fields) - 1; i > 0; i-- { + if enc.fields[i-1].id == enc.fields[i].id { + panic(fmt.Errorf("thrift struct field id %d is present multiple times", enc.fields[i].id)) + } + } + + return encode +} + +func encodeFuncStructFieldOf(f structField, seen encodeFuncCache) encodeFunc { + if f.flags.have(enum) { + switch f.typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return encodeInt32 + } + } + return encodeFuncOf(f.typ, seen) +} + +func encodeFuncPtrOf(t reflect.Type, seen encodeFuncCache) encodeFunc { + typ := t.Elem() + enc := encodeFuncOf(typ, seen) + zero := reflect.Zero(typ) + + return func(w Writer, v reflect.Value, f flags) error { + if v.IsNil() { + v = zero + } + return enc(w, v, f) + } +} diff --git a/thrift/error.go b/thrift/error.go new file mode 100644 index 0000000..ceeb1ba --- /dev/null +++ b/thrift/error.go @@ -0,0 +1,111 @@ +package thrift + +import ( + "errors" + "fmt" + "io" + "strings" +) + +type MissingField struct { + Field Field +} + +func (e *MissingField) Error() string { + return fmt.Sprintf("missing required field: %s", e.Field) +} + +type TypeMismatch struct { + Expect Type + Found Type + item string +} + +func (e *TypeMismatch) Error() string { + return fmt.Sprintf("%s type mismatch: expected %s but found %s", e.item, e.Expect, e.Found) +} + +type decodeError struct { + base error + path []error +} + +func (e *decodeError) Error() string { + s := strings.Builder{} + s.Grow(256) + s.WriteString("decoding thrift payload: ") + + if len(e.path) != 0 { + n := len(e.path) - 1 + for i := n; i >= 0; i-- { + if i < n { + s.WriteString(" → ") + } + s.WriteString(e.path[i].Error()) + } + s.WriteString(": ") + } + + s.WriteString(e.base.Error()) + return s.String() +} + +func (e *decodeError) Unwrap() error { return e.base } + +func with(base, elem error) error { + if errors.Is(base, io.EOF) { + return base + } + e, _ := base.(*decodeError) + if e == nil { + e = &decodeError{base: base} + } + e.path = append(e.path, elem) + return e +} + +type decodeErrorField struct { + cause Field +} + +func (d *decodeErrorField) Error() string { + return d.cause.String() +} + +type decodeErrorList struct { + cause List + index int +} + +func (d *decodeErrorList) Error() string { + return fmt.Sprintf("%d/%d:%s", d.index, d.cause.Size, d.cause) +} + +type decodeErrorSet struct { + cause Set + index int +} + +func (d *decodeErrorSet) Error() string { + return fmt.Sprintf("%d/%d:%s", d.index, d.cause.Size, d.cause) +} + +type decodeErrorMap struct { + cause Map + index int +} + +func (d *decodeErrorMap) Error() string { + return fmt.Sprintf("%d/%d:%s", d.index, d.cause.Size, d.cause) +} + +func dontExpectEOF(err error) error { + switch err { + case nil: + return nil + case io.EOF: + return io.ErrUnexpectedEOF + default: + return err + } +} diff --git a/thrift/protocol.go b/thrift/protocol.go new file mode 100644 index 0000000..7c31338 --- /dev/null +++ b/thrift/protocol.go @@ -0,0 +1,73 @@ +package thrift + +import ( + "io" +) + +// Features is a bitset describing the thrift encoding features supported by +// protocol implementations. +type Features uint + +const ( + // DeltaEncoding is advertised by protocols that allow encoders to apply + // delta encoding on struct fields. + UseDeltaEncoding Features = 1 << iota + + // CoalesceBoolFields is advertised by protocols that allow encoders to + // coalesce boolean values into field types. + CoalesceBoolFields +) + +// The Protocol interface abstracts the creation of low-level thrift readers and +// writers implementing the various protocols that the encoding supports. +// +// Protocol instances must be safe to use concurrently from multiple gourintes. +// However, the readers and writer that they instantiates are intended to be +// used by a single goroutine. +type Protocol interface { + NewReader(r io.Reader) Reader + NewWriter(w io.Writer) Writer + Features() Features +} + +// Reader represents a low-level reader of values encoded according to one of +// the thrift protocols. +type Reader interface { + Protocol() Protocol + Reader() io.Reader + ReadBool() (bool, error) + ReadInt8() (int8, error) + ReadInt16() (int16, error) + ReadInt32() (int32, error) + ReadInt64() (int64, error) + ReadFloat64() (float64, error) + ReadBytes() ([]byte, error) + ReadString() (string, error) + ReadLength() (int, error) + ReadMessage() (Message, error) + ReadField() (Field, error) + ReadList() (List, error) + ReadSet() (Set, error) + ReadMap() (Map, error) +} + +// Writer represents a low-level writer of values encoded according to one of +// the thrift protocols. +type Writer interface { + Protocol() Protocol + Writer() io.Writer + WriteBool(bool) error + WriteInt8(int8) error + WriteInt16(int16) error + WriteInt32(int32) error + WriteInt64(int64) error + WriteFloat64(float64) error + WriteBytes([]byte) error + WriteString(string) error + WriteLength(int) error + WriteMessage(Message) error + WriteField(Field) error + WriteList(List) error + WriteSet(Set) error + WriteMap(Map) error +} diff --git a/thrift/protocol_test.go b/thrift/protocol_test.go new file mode 100644 index 0000000..0bdd38e --- /dev/null +++ b/thrift/protocol_test.go @@ -0,0 +1,204 @@ +package thrift_test + +import ( + "bytes" + "reflect" + "strings" + "testing" + + "github.com/segmentio/encoding/thrift" +) + +var protocolReadWriteTests = [...]struct { + scenario string + read interface{} + write interface{} + values []interface{} +}{ + { + scenario: "bool", + read: thrift.Reader.ReadBool, + write: thrift.Writer.WriteBool, + values: []interface{}{false, true}, + }, + + { + scenario: "int8", + read: thrift.Reader.ReadInt8, + write: thrift.Writer.WriteInt8, + values: []interface{}{int8(0), int8(1), int8(-1)}, + }, + + { + scenario: "int16", + read: thrift.Reader.ReadInt16, + write: thrift.Writer.WriteInt16, + values: []interface{}{int16(0), int16(1), int16(-1)}, + }, + + { + scenario: "int32", + read: thrift.Reader.ReadInt32, + write: thrift.Writer.WriteInt32, + values: []interface{}{int32(0), int32(1), int32(-1)}, + }, + + { + scenario: "int64", + read: thrift.Reader.ReadInt64, + write: thrift.Writer.WriteInt64, + values: []interface{}{int64(0), int64(1), int64(-1)}, + }, + + { + scenario: "float64", + read: thrift.Reader.ReadFloat64, + write: thrift.Writer.WriteFloat64, + values: []interface{}{float64(0), float64(1), float64(-1)}, + }, + + { + scenario: "bytes", + read: thrift.Reader.ReadBytes, + write: thrift.Writer.WriteBytes, + values: []interface{}{ + []byte(""), + []byte("A"), + []byte("1234567890"), + bytes.Repeat([]byte("qwertyuiop"), 100), + }, + }, + + { + scenario: "string", + read: thrift.Reader.ReadString, + write: thrift.Writer.WriteString, + values: []interface{}{ + "", + "A", + "1234567890", + strings.Repeat("qwertyuiop", 100), + }, + }, + + { + scenario: "message", + read: thrift.Reader.ReadMessage, + write: thrift.Writer.WriteMessage, + values: []interface{}{ + thrift.Message{}, + thrift.Message{Type: thrift.Call, Name: "Hello", SeqID: 10}, + thrift.Message{Type: thrift.Reply, Name: "World", SeqID: 11}, + thrift.Message{Type: thrift.Exception, Name: "Foo", SeqID: 40}, + thrift.Message{Type: thrift.Oneway, Name: "Bar", SeqID: 42}, + }, + }, + + { + scenario: "field", + read: thrift.Reader.ReadField, + write: thrift.Writer.WriteField, + values: []interface{}{ + thrift.Field{ID: 101, Type: thrift.TRUE}, + thrift.Field{ID: 102, Type: thrift.FALSE}, + thrift.Field{ID: 103, Type: thrift.I8}, + thrift.Field{ID: 104, Type: thrift.I16}, + thrift.Field{ID: 105, Type: thrift.I32}, + thrift.Field{ID: 106, Type: thrift.I64}, + thrift.Field{ID: 107, Type: thrift.DOUBLE}, + thrift.Field{ID: 108, Type: thrift.BINARY}, + thrift.Field{ID: 109, Type: thrift.LIST}, + thrift.Field{ID: 110, Type: thrift.SET}, + thrift.Field{ID: 111, Type: thrift.MAP}, + thrift.Field{ID: 112, Type: thrift.STRUCT}, + thrift.Field{}, + }, + }, + + { + scenario: "list", + read: thrift.Reader.ReadList, + write: thrift.Writer.WriteList, + values: []interface{}{ + thrift.List{}, + thrift.List{Size: 0, Type: thrift.BOOL}, + thrift.List{Size: 1, Type: thrift.I8}, + thrift.List{Size: 1000, Type: thrift.BINARY}, + }, + }, + + { + scenario: "map", + read: thrift.Reader.ReadMap, + write: thrift.Writer.WriteMap, + values: []interface{}{ + thrift.Map{}, + thrift.Map{Size: 1, Key: thrift.BINARY, Value: thrift.MAP}, + thrift.Map{Size: 1000, Key: thrift.BINARY, Value: thrift.LIST}, + }, + }, +} + +var protocols = [...]struct { + name string + proto thrift.Protocol +}{ + { + name: "binary(default)", + proto: &thrift.BinaryProtocol{}, + }, + + { + name: "binary(non-strict)", + proto: &thrift.BinaryProtocol{ + NonStrict: true, + }, + }, + + { + name: "compact", + proto: &thrift.CompactProtocol{}, + }, +} + +func TestProtocols(t *testing.T) { + for _, test := range protocols { + t.Run(test.name, func(t *testing.T) { testProtocolReadWriteValues(t, test.proto) }) + } +} + +func testProtocolReadWriteValues(t *testing.T, p thrift.Protocol) { + for _, test := range protocolReadWriteTests { + t.Run(test.scenario, func(t *testing.T) { + b := new(bytes.Buffer) + r := p.NewReader(b) + w := p.NewWriter(b) + + for _, value := range test.values { + ret := reflect.ValueOf(test.write).Call([]reflect.Value{ + reflect.ValueOf(w), + reflect.ValueOf(value), + }) + if err, _ := ret[0].Interface().(error); err != nil { + t.Fatal("encoding:", err) + } + } + + for _, value := range test.values { + ret := reflect.ValueOf(test.read).Call([]reflect.Value{ + reflect.ValueOf(r), + }) + if err, _ := ret[1].Interface().(error); err != nil { + t.Fatal("decoding:", err) + } + if res := ret[0].Interface(); !reflect.DeepEqual(value, res) { + t.Errorf("value mismatch:\nwant: %#v\ngot: %#v", value, res) + } + } + + if b.Len() != 0 { + t.Errorf("unexpected trailing bytes: %d", b.Len()) + } + }) + } +} diff --git a/thrift/struct.go b/thrift/struct.go new file mode 100644 index 0000000..4bc61eb --- /dev/null +++ b/thrift/struct.go @@ -0,0 +1,132 @@ +package thrift + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +type flags int16 + +const ( + enum flags = 1 << 0 + union flags = 1 << 1 + required flags = 1 << 2 + optional flags = 1 << 3 + strict flags = 1 << 4 + + featuresBitOffset = 8 + useDeltaEncoding = flags(UseDeltaEncoding) << featuresBitOffset + coalesceBoolFields = flags(CoalesceBoolFields) << featuresBitOffset + + structFlags flags = enum | union | required | optional + encodeFlags flags = strict | protocolFlags + decodeFlags flags = strict | protocolFlags + protocolFlags flags = useDeltaEncoding | coalesceBoolFields +) + +func (f flags) have(x flags) bool { + return (f & x) == x +} + +func (f flags) only(x flags) flags { + return f & x +} + +func (f flags) with(x flags) flags { + return f | x +} + +func (f flags) without(x flags) flags { + return f & ^x +} + +type structField struct { + typ reflect.Type + index []int + id int16 + flags flags +} + +func forEachStructField(t reflect.Type, index []int, do func(structField)) { + for i, n := 0, t.NumField(); i < n; i++ { + f := t.Field(i) + + if f.PkgPath != "" && !f.Anonymous { // unexported + continue + } + + fieldIndex := append(index, i) + fieldIndex = fieldIndex[:len(fieldIndex):len(fieldIndex)] + + if f.Anonymous && f.Type.Kind() == reflect.Struct { + forEachStructField(f.Type, fieldIndex, do) + continue + } + + tag := f.Tag.Get("thrift") + if tag == "" { + continue + } + tags := strings.Split(tag, ",") + flags := flags(0) + + for _, opt := range tags[1:] { + switch opt { + case "enum": + flags = flags.with(enum) + case "union": + flags = flags.with(union) + case "required": + flags = flags.with(required) + case "optional": + flags = flags.with(optional) + default: + panic(fmt.Errorf("thrift struct field contains an unknown tag option %q in `thrift:\"%s\"`", opt, tag)) + } + } + + if flags.have(optional | required) { + panic(fmt.Errorf("thrift struct field cannot be both optional and required in `thrift:\"%s\"`", tag)) + } + + if flags.have(union) { + if f.Type.Kind() != reflect.Interface { + panic(fmt.Errorf("thrift union tag found on a field which is not an interface type `thrift:\"%s\"`", tag)) + } + + if tags[0] != "" { + panic(fmt.Errorf("invalid thrift field id on union field `thrift:\"%s\"`", tag)) + } + + do(structField{ + typ: f.Type, + index: fieldIndex, + flags: flags, + }) + } else { + if flags.have(enum) { + switch f.Type.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + default: + panic(fmt.Errorf("thrift enum tag found on a field which is not an integer type `thrift:\"%s\"`", tag)) + } + } + + if id, err := strconv.ParseInt(tags[0], 10, 16); err != nil { + panic(fmt.Errorf("invalid thrift field id found in struct tag `thrift:\"%s\"`: %w", tag, err)) + } else if id <= 0 { + panic(fmt.Errorf("invalid thrift field id found in struct tag `thrift:\"%s\"`: %d <= 0", tag, id)) + } else { + do(structField{ + typ: f.Type, + index: fieldIndex, + id: int16(id), + flags: flags, + }) + } + } + } +} diff --git a/thrift/thrift.go b/thrift/thrift.go new file mode 100644 index 0000000..b3682e4 --- /dev/null +++ b/thrift/thrift.go @@ -0,0 +1,164 @@ +package thrift + +import ( + "fmt" + "reflect" +) + +type Message struct { + Type MessageType + Name string + SeqID int32 +} + +type MessageType int8 + +const ( + Call MessageType = iota + Reply + Exception + Oneway +) + +func (m MessageType) String() string { + switch m { + case Call: + return "Call" + case Reply: + return "Reply" + case Exception: + return "Exception" + case Oneway: + return "Oneway" + default: + return "?" + } +} + +type Field struct { + ID int16 + Type Type + Delta bool // whether the field id is a delta +} + +func (f Field) String() string { + return fmt.Sprintf("%d:FIELD<%s>", f.ID, f.Type) +} + +type Type int8 + +const ( + STOP Type = iota + TRUE + FALSE + I8 + I16 + I32 + I64 + DOUBLE + BINARY + LIST + SET + MAP + STRUCT + BOOL = FALSE +) + +func (t Type) String() string { + switch t { + case STOP: + return "STOP" + case TRUE: + return "TRUE" + case BOOL: + return "BOOL" + case I8: + return "I8" + case I16: + return "I16" + case I32: + return "I32" + case I64: + return "I64" + case DOUBLE: + return "DOUBLE" + case BINARY: + return "BINARY" + case LIST: + return "LIST" + case SET: + return "SET" + case MAP: + return "MAP" + case STRUCT: + return "STRUCT" + default: + return "?" + } +} + +func (t Type) GoString() string { + return "thrift." + t.String() +} + +type List struct { + Size int32 + Type Type +} + +func (l List) String() string { + return fmt.Sprintf("LIST<%s>", l.Type) +} + +type Set List + +func (s Set) String() string { + return fmt.Sprintf("SET<%s>", s.Type) +} + +type Map struct { + Size int32 + Key Type + Value Type +} + +func (m Map) String() string { + return fmt.Sprintf("MAP<%s,%s>", m.Key, m.Value) +} + +func TypeOf(t reflect.Type) Type { + switch t.Kind() { + case reflect.Bool: + return BOOL + case reflect.Int8, reflect.Uint8: + return I8 + case reflect.Int16, reflect.Uint16: + return I16 + case reflect.Int32, reflect.Uint32: + return I32 + case reflect.Int64, reflect.Uint64, reflect.Int, reflect.Uint, reflect.Uintptr: + return I64 + case reflect.Float32, reflect.Float64: + return DOUBLE + case reflect.String: + return BINARY + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { // []byte + return BINARY + } else { + return LIST + } + case reflect.Map: + if t.Elem().Size() == 0 { + return SET + } else { + return MAP + } + case reflect.Struct: + return STRUCT + case reflect.Ptr: + return TypeOf(t.Elem()) + default: + panic("type cannot be represented in thrift: " + t.String()) + } +} diff --git a/thrift/thrift_test.go b/thrift/thrift_test.go new file mode 100644 index 0000000..e8f46e8 --- /dev/null +++ b/thrift/thrift_test.go @@ -0,0 +1,298 @@ +package thrift_test + +import ( + "bytes" + "math" + "reflect" + "strings" + "testing" + + "github.com/segmentio/encoding/thrift" +) + +var marshalTestValues = [...]struct { + scenario string + values []interface{} +}{ + { + scenario: "bool", + values: []interface{}{false, true}, + }, + + { + scenario: "int", + values: []interface{}{ + int(0), + int(-1), + int(1), + }, + }, + + { + scenario: "int8", + values: []interface{}{ + int8(0), + int8(-1), + int8(1), + int8(math.MinInt8), + int8(math.MaxInt8), + }, + }, + + { + scenario: "int16", + values: []interface{}{ + int16(0), + int16(-1), + int16(1), + int16(math.MinInt16), + int16(math.MaxInt16), + }, + }, + + { + scenario: "int32", + values: []interface{}{ + int32(0), + int32(-1), + int32(1), + int32(math.MinInt32), + int32(math.MaxInt32), + }, + }, + + { + scenario: "int64", + values: []interface{}{ + int64(0), + int64(-1), + int64(1), + int64(math.MinInt64), + int64(math.MaxInt64), + }, + }, + + { + scenario: "string", + values: []interface{}{ + "", + "A", + "1234567890", + strings.Repeat("qwertyuiop", 100), + }, + }, + + { + scenario: "[]byte", + values: []interface{}{ + []byte(""), + []byte("A"), + []byte("1234567890"), + bytes.Repeat([]byte("qwertyuiop"), 100), + }, + }, + + { + scenario: "[]string", + values: []interface{}{ + []string{}, + []string{"A"}, + []string{"hello", "world", "!!!"}, + []string{"0", "1", "3", "4", "5", "6", "7", "8", "9"}, + }, + }, + + { + scenario: "map[string]int", + values: []interface{}{ + map[string]int{}, + map[string]int{"A": 1}, + map[string]int{"hello": 1, "world": 2, "answer": 42}, + }, + }, + + { + scenario: "map[int64]struct{}", + values: []interface{}{ + map[int64]struct{}{}, + map[int64]struct{}{0: {}, 1: {}, 2: {}}, + }, + }, + + { + scenario: "[]map[string]struct{}", + values: []interface{}{ + []map[string]struct{}{}, + []map[string]struct{}{{}, {"A": {}, "B": {}, "C": {}}}, + }, + }, + + { + scenario: "struct{}", + values: []interface{}{struct{}{}}, + }, + + { + scenario: "Point2D", + values: []interface{}{ + Point2D{}, + Point2D{X: 1}, + Point2D{Y: 2}, + Point2D{X: 3, Y: 4}, + }, + }, + + { + scenario: "RecursiveStruct", + values: []interface{}{ + RecursiveStruct{}, + RecursiveStruct{Value: "hello"}, + RecursiveStruct{Value: "hello", Next: &RecursiveStruct{}}, + RecursiveStruct{Value: "hello", Next: &RecursiveStruct{Value: "world"}}, + }, + }, + + { + scenario: "StructWithEnum", + values: []interface{}{ + StructWithEnum{}, + StructWithEnum{Enum: 1}, + StructWithEnum{Enum: 2}, + }, + }, + + { + scenario: "Union", + values: []interface{}{ + Union{}, + Union{A: true, F: newBool(true)}, + Union{B: 42, F: newInt(42)}, + Union{C: "hello world!", F: newString("hello world!")}, + }, + }, +} + +type Point2D struct { + X float64 `thrift:"1,required"` + Y float64 `thrift:"2,required"` +} + +type RecursiveStruct struct { + Value string `thrift:"1"` + Next *RecursiveStruct `thrift:"2"` +} + +type StructWithEnum struct { + Enum int8 `thrift:"1,enum"` +} + +type Union struct { + A bool `thrift:"1"` + B int `thrift:"2"` + C string `thrift:"3"` + F interface{} `thrift:",union"` +} + +func newBool(b bool) *bool { return &b } +func newInt(i int) *int { return &i } +func newString(s string) *string { return &s } + +func TestMarshalUnmarshal(t *testing.T) { + for _, p := range protocols { + t.Run(p.name, func(t *testing.T) { testMarshalUnmarshal(t, p.proto) }) + } +} + +func testMarshalUnmarshal(t *testing.T, p thrift.Protocol) { + for _, test := range marshalTestValues { + t.Run(test.scenario, func(t *testing.T) { + for _, value := range test.values { + b, err := thrift.Marshal(p, value) + if err != nil { + t.Fatal("marshal:", err) + } + + v := reflect.New(reflect.TypeOf(value)) + if err := thrift.Unmarshal(p, b, v.Interface()); err != nil { + t.Fatal("unmarshal:", err) + } + + if result := v.Elem().Interface(); !reflect.DeepEqual(value, result) { + t.Errorf("value mismatch:\nwant: %#v\ngot: %#v", value, result) + } + } + }) + } +} + +func BenchmarkMarshal(b *testing.B) { + for _, p := range protocols { + b.Run(p.name, func(b *testing.B) { benchmarkMarshal(b, p.proto) }) + } +} + +type BenchmarkEncodeType struct { + Name string `thrift:"1"` + Question string `thrift:"2"` + Answer string `thrift:"3"` + Sub *BenchmarkEncodeType `thrift:"4"` +} + +func benchmarkMarshal(b *testing.B, p thrift.Protocol) { + buf := new(bytes.Buffer) + enc := thrift.NewEncoder(p.NewWriter(buf)) + val := &BenchmarkEncodeType{ + Name: "Luke", + Question: "How are you?", + Answer: "42", + Sub: &BenchmarkEncodeType{ + Name: "Leia", + Question: "?", + Answer: "whatever", + }, + } + + for i := 0; i < b.N; i++ { + buf.Reset() + enc.Encode(val) + } + + b.SetBytes(int64(buf.Len())) +} + +func BenchmarkUnmarshal(b *testing.B) { + for _, p := range protocols { + b.Run(p.name, func(b *testing.B) { benchmarkUnmarshal(b, p.proto) }) + } +} + +type BenchmarkDecodeType struct { + Name string `thrift:"1"` + Question string `thrift:"2"` + Answer string `thrift:"3"` + Sub *BenchmarkDecodeType `thrift:"4"` +} + +func benchmarkUnmarshal(b *testing.B, p thrift.Protocol) { + buf, _ := thrift.Marshal(p, &BenchmarkDecodeType{ + Name: "Luke", + Question: "How are you?", + Answer: "42", + Sub: &BenchmarkDecodeType{ + Name: "Leia", + Question: "?", + Answer: "whatever", + }, + }) + + rb := bytes.NewReader(nil) + dec := thrift.NewDecoder(p.NewReader(rb)) + val := &BenchmarkDecodeType{} + + for i := 0; i < b.N; i++ { + rb.Reset(buf) + dec.Decode(val) + } + + b.SetBytes(int64(len(buf))) +} diff --git a/thrift/unsafe.go b/thrift/unsafe.go new file mode 100644 index 0000000..9572b40 --- /dev/null +++ b/thrift/unsafe.go @@ -0,0 +1,24 @@ +package thrift + +import ( + "reflect" + "unsafe" +) + +// typeID is used as key in encoder and decoder caches to enable using +// the optimize runtime.mapaccess2_fast64 function instead of the more +// expensive lookup if we were to use reflect.Type as map key. +// +// typeID holds the pointer to the reflect.Type value, which is unique +// in the program. +type typeID struct{ ptr unsafe.Pointer } + +func makeTypeID(t reflect.Type) typeID { + return typeID{ + ptr: (*[2]unsafe.Pointer)(unsafe.Pointer(&t))[1], + } +} + +func unsafeBytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +}