-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add thrift package * add encoder and decoder APIs * add a bit of documentation * add enum struct field tag * add a test for the enum struct type * use ioutil.Discard to be compatible with Go 1.14 * add support for union types * don't use IsExported to be compatible with older Go versions * add support for required fields * improve union model * improve error reporting * improve error reporting (2) * add a strict mode to the decoder * add debug reader and writer * use reflect.Value.IsZero * simplify API + fix use of non-zigzag varint encoding * simplify and optimize reading of BINARY type * fixes * test that io.EOF is returned when decoding an empty input * optimize required fields validation * cleanup * fix missing field index * consider anonymous embedded structs * rename error fields to 'cause' * remove magic support for unsigned integer values
- Loading branch information
Achille
authored
Nov 9, 2021
1 parent
cc0871e
commit 47bdac2
Showing
14 changed files
with
3,015 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,367 @@ | ||
package thrift | ||
|
||
import ( | ||
"bufio" | ||
"bytes" | ||
"encoding/binary" | ||
"fmt" | ||
"io" | ||
"math" | ||
) | ||
|
||
// BinaryProtocol is a Protocol implementation for the binary thrift protocol. | ||
// | ||
// https://github.com/apache/thrift/blob/master/doc/specs/thrift-binary-protocol.md | ||
type BinaryProtocol struct { | ||
NonStrict bool | ||
} | ||
|
||
func (p *BinaryProtocol) NewReader(r io.Reader) Reader { | ||
return &binaryReader{p: p, r: r} | ||
} | ||
|
||
func (p *BinaryProtocol) NewWriter(w io.Writer) Writer { | ||
return &binaryWriter{p: p, w: w} | ||
} | ||
|
||
func (p *BinaryProtocol) Features() Features { | ||
return 0 | ||
} | ||
|
||
type binaryReader struct { | ||
p *BinaryProtocol | ||
r io.Reader | ||
b [8]byte | ||
} | ||
|
||
func (r *binaryReader) Protocol() Protocol { | ||
return r.p | ||
} | ||
|
||
func (r *binaryReader) Reader() io.Reader { | ||
return r.r | ||
} | ||
|
||
func (r *binaryReader) ReadBool() (bool, error) { | ||
v, err := r.ReadByte() | ||
return v != 0, err | ||
} | ||
|
||
func (r *binaryReader) ReadInt8() (int8, error) { | ||
b, err := r.ReadByte() | ||
return int8(b), err | ||
} | ||
|
||
func (r *binaryReader) ReadInt16() (int16, error) { | ||
b, err := r.read(2) | ||
if len(b) < 2 { | ||
return 0, err | ||
} | ||
return int16(binary.BigEndian.Uint16(b)), nil | ||
} | ||
|
||
func (r *binaryReader) ReadInt32() (int32, error) { | ||
b, err := r.read(4) | ||
if len(b) < 4 { | ||
return 0, err | ||
} | ||
return int32(binary.BigEndian.Uint32(b)), nil | ||
} | ||
|
||
func (r *binaryReader) ReadInt64() (int64, error) { | ||
b, err := r.read(8) | ||
if len(b) < 8 { | ||
return 0, err | ||
} | ||
return int64(binary.BigEndian.Uint64(b)), nil | ||
} | ||
|
||
func (r *binaryReader) ReadFloat64() (float64, error) { | ||
b, err := r.read(8) | ||
if len(b) < 8 { | ||
return 0, err | ||
} | ||
return math.Float64frombits(binary.BigEndian.Uint64(b)), nil | ||
} | ||
|
||
func (r *binaryReader) ReadBytes() ([]byte, error) { | ||
n, err := r.ReadLength() | ||
if err != nil { | ||
return nil, err | ||
} | ||
b := make([]byte, n) | ||
_, err = io.ReadFull(r.r, b) | ||
return b, err | ||
} | ||
|
||
func (r *binaryReader) ReadString() (string, error) { | ||
b, err := r.ReadBytes() | ||
return unsafeBytesToString(b), err | ||
} | ||
|
||
func (r *binaryReader) ReadLength() (int, error) { | ||
b, err := r.read(4) | ||
if len(b) < 4 { | ||
return 0, err | ||
} | ||
n := binary.BigEndian.Uint32(b) | ||
if n > math.MaxInt32 { | ||
return 0, fmt.Errorf("length out of range: %d", n) | ||
} | ||
return int(n), nil | ||
} | ||
|
||
func (r *binaryReader) ReadMessage() (Message, error) { | ||
m := Message{} | ||
|
||
b, err := r.read(4) | ||
if len(b) < 4 { | ||
return m, err | ||
} | ||
|
||
if (b[0] >> 7) == 0 { // non-strict | ||
n := int(binary.BigEndian.Uint32(b)) | ||
s := make([]byte, n) | ||
_, err := io.ReadFull(r.r, s) | ||
if err != nil { | ||
return m, dontExpectEOF(err) | ||
} | ||
m.Name = unsafeBytesToString(s) | ||
|
||
t, err := r.ReadInt8() | ||
if err != nil { | ||
return m, dontExpectEOF(err) | ||
} | ||
|
||
m.Type = MessageType(t & 0x7) | ||
} else { | ||
m.Type = MessageType(b[3] & 0x7) | ||
|
||
if m.Name, err = r.ReadString(); err != nil { | ||
return m, dontExpectEOF(err) | ||
} | ||
} | ||
|
||
m.SeqID, err = r.ReadInt32() | ||
return m, err | ||
} | ||
|
||
func (r *binaryReader) ReadField() (Field, error) { | ||
t, err := r.ReadInt8() | ||
if err != nil { | ||
return Field{}, err | ||
} | ||
i, err := r.ReadInt16() | ||
if err != nil { | ||
return Field{}, err | ||
} | ||
return Field{ID: i, Type: Type(t)}, nil | ||
} | ||
|
||
func (r *binaryReader) ReadList() (List, error) { | ||
t, err := r.ReadInt8() | ||
if err != nil { | ||
return List{}, err | ||
} | ||
n, err := r.ReadInt32() | ||
if err != nil { | ||
return List{}, dontExpectEOF(err) | ||
} | ||
return List{Size: n, Type: Type(t)}, nil | ||
} | ||
|
||
func (r *binaryReader) ReadSet() (Set, error) { | ||
l, err := r.ReadList() | ||
return Set(l), err | ||
} | ||
|
||
func (r *binaryReader) ReadMap() (Map, error) { | ||
k, err := r.ReadByte() | ||
if err != nil { | ||
return Map{}, err | ||
} | ||
v, err := r.ReadByte() | ||
if err != nil { | ||
return Map{}, dontExpectEOF(err) | ||
} | ||
n, err := r.ReadInt32() | ||
if err != nil { | ||
return Map{}, dontExpectEOF(err) | ||
} | ||
return Map{Size: n, Key: Type(k), Value: Type(v)}, nil | ||
} | ||
|
||
func (r *binaryReader) ReadByte() (byte, error) { | ||
switch x := r.r.(type) { | ||
case *bytes.Buffer: | ||
return x.ReadByte() | ||
case *bytes.Reader: | ||
return x.ReadByte() | ||
case *bufio.Reader: | ||
return x.ReadByte() | ||
case io.ByteReader: | ||
return x.ReadByte() | ||
default: | ||
b, err := r.read(1) | ||
if err != nil { | ||
return 0, err | ||
} | ||
return b[0], nil | ||
} | ||
} | ||
|
||
func (r *binaryReader) read(n int) ([]byte, error) { | ||
_, err := io.ReadFull(r.r, r.b[:n]) | ||
return r.b[:n], err | ||
} | ||
|
||
type binaryWriter struct { | ||
p *BinaryProtocol | ||
b [8]byte | ||
w io.Writer | ||
} | ||
|
||
func (w *binaryWriter) Protocol() Protocol { | ||
return w.p | ||
} | ||
|
||
func (w *binaryWriter) Writer() io.Writer { | ||
return w.w | ||
} | ||
|
||
func (w *binaryWriter) WriteBool(v bool) error { | ||
var b byte | ||
if v { | ||
b = 1 | ||
} | ||
return w.writeByte(b) | ||
} | ||
|
||
func (w *binaryWriter) WriteInt8(v int8) error { | ||
return w.writeByte(byte(v)) | ||
} | ||
|
||
func (w *binaryWriter) WriteInt16(v int16) error { | ||
binary.BigEndian.PutUint16(w.b[:2], uint16(v)) | ||
return w.write(w.b[:2]) | ||
} | ||
|
||
func (w *binaryWriter) WriteInt32(v int32) error { | ||
binary.BigEndian.PutUint32(w.b[:4], uint32(v)) | ||
return w.write(w.b[:4]) | ||
} | ||
|
||
func (w *binaryWriter) WriteInt64(v int64) error { | ||
binary.BigEndian.PutUint64(w.b[:8], uint64(v)) | ||
return w.write(w.b[:8]) | ||
} | ||
|
||
func (w *binaryWriter) WriteFloat64(v float64) error { | ||
binary.BigEndian.PutUint64(w.b[:8], math.Float64bits(v)) | ||
return w.write(w.b[:8]) | ||
} | ||
|
||
func (w *binaryWriter) WriteBytes(v []byte) error { | ||
if err := w.WriteLength(len(v)); err != nil { | ||
return err | ||
} | ||
return w.write(v) | ||
} | ||
|
||
func (w *binaryWriter) WriteString(v string) error { | ||
if err := w.WriteLength(len(v)); err != nil { | ||
return err | ||
} | ||
return w.writeString(v) | ||
} | ||
|
||
func (w *binaryWriter) WriteLength(n int) error { | ||
if n < 0 { | ||
return fmt.Errorf("negative length cannot be encoded in thrift: %d", n) | ||
} | ||
if n > math.MaxInt32 { | ||
return fmt.Errorf("length is too large to be encoded in thrift: %d", n) | ||
} | ||
return w.WriteInt32(int32(n)) | ||
} | ||
|
||
func (w *binaryWriter) WriteMessage(m Message) error { | ||
if w.p.NonStrict { | ||
if err := w.WriteString(m.Name); err != nil { | ||
return err | ||
} | ||
if err := w.writeByte(byte(m.Type)); err != nil { | ||
return err | ||
} | ||
} else { | ||
w.b[0] = 1 << 7 | ||
w.b[1] = 0 | ||
w.b[2] = 0 | ||
w.b[3] = byte(m.Type) & 0x7 | ||
binary.BigEndian.PutUint32(w.b[4:], uint32(len(m.Name))) | ||
|
||
if err := w.write(w.b[:8]); err != nil { | ||
return err | ||
} | ||
if err := w.writeString(m.Name); err != nil { | ||
return err | ||
} | ||
} | ||
return w.WriteInt32(m.SeqID) | ||
} | ||
|
||
func (w *binaryWriter) WriteField(f Field) error { | ||
if err := w.writeByte(byte(f.Type)); err != nil { | ||
return err | ||
} | ||
return w.WriteInt16(f.ID) | ||
} | ||
|
||
func (w *binaryWriter) WriteList(l List) error { | ||
if err := w.writeByte(byte(l.Type)); err != nil { | ||
return err | ||
} | ||
return w.WriteInt32(l.Size) | ||
} | ||
|
||
func (w *binaryWriter) WriteSet(s Set) error { | ||
return w.WriteList(List(s)) | ||
} | ||
|
||
func (w *binaryWriter) WriteMap(m Map) error { | ||
if err := w.writeByte(byte(m.Key)); err != nil { | ||
return err | ||
} | ||
if err := w.writeByte(byte(m.Value)); err != nil { | ||
return err | ||
} | ||
return w.WriteInt32(m.Size) | ||
} | ||
|
||
func (w *binaryWriter) write(b []byte) error { | ||
_, err := w.w.Write(b) | ||
return err | ||
} | ||
|
||
func (w *binaryWriter) writeString(s string) error { | ||
_, err := io.WriteString(w.w, s) | ||
return err | ||
} | ||
|
||
func (w *binaryWriter) writeByte(b byte) error { | ||
// The special cases are intended to reduce the runtime overheadof testing | ||
// for the io.ByteWriter interface for common types. Type assertions on a | ||
// concrete type is just a pointer comparison, instead of requiring a | ||
// complex lookup in the type metadata. | ||
switch x := w.w.(type) { | ||
case *bytes.Buffer: | ||
return x.WriteByte(b) | ||
case *bufio.Writer: | ||
return x.WriteByte(b) | ||
case io.ByteWriter: | ||
return x.WriteByte(b) | ||
default: | ||
w.b[0] = b | ||
return w.write(w.b[:1]) | ||
} | ||
} |
Oops, something went wrong.