Skip to content

Commit

Permalink
refactor(fastgo): new append method (#234)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaost authored Oct 29, 2024
1 parent 9e52336 commit 540d376
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 90 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ thriftgo:
go install

testall: thriftgo
@for d in test/*; do $(MAKE) -C $$d; done
@set -e; for d in test/*; do $(MAKE) -C $$d; done

clean:
rm -rf $(COV_PROF) $(IDL)
Expand Down
15 changes: 8 additions & 7 deletions generator/fastgo/gen_blength.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ func genBLengthField(w *codewriter, rwctx *golang.ReadWriteContext, f *golang.Fi
}

// field header
w.f("off += 3") // type + fid
// 3 will be added in genBLengthAny
// w.f("off += 3")

// field value
genBLengthAny(w, rwctx, varname, 0)
Expand All @@ -89,7 +90,7 @@ func genBLengthField(w *codewriter, rwctx *golang.ReadWriteContext, f *golang.Fi
func genBLengthAny(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) {
t := rwctx.Type
if sz := category2WireSize[t.Category]; sz > 0 {
w.f("off += %d", sz)
w.f("off += 3 + %d", sz)
return
}
pointer := rwctx.IsPointer
Expand All @@ -107,22 +108,22 @@ func genBLengthAny(w *codewriter, rwctx *golang.ReadWriteContext, varname string

func genBLengthBinary(w *codewriter, pointer bool, varname string) {
varname = varnameVal(pointer, varname)
w.f("off += 4 + len(%s)", varname)
w.f("off += 3 + 4 + len(%s)", varname)
}

func genBLengthString(w *codewriter, pointer bool, varname string) {
varname = varnameVal(pointer, varname)
w.f("off += 4 + len(%s)", varname)
w.f("off += 3 + 4 + len(%s)", varname)
}

func genBLengthStruct(w *codewriter, _ *golang.ReadWriteContext, varname string) {
w.f("off += %s.BLength()", varname)
w.f("off += 3 + %s.BLength()", varname)
}

func genBLengthList(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) {
t := rwctx.Type
// list header
w.f("off += 5")
w.f("off += 3 + 5")

// if element is basic type like int32, we can speed up the calc by sizeof(int32) * len(l)
if sz := category2WireSize[t.ValueType.Category]; sz > 0 { // fast path for less code
Expand All @@ -146,7 +147,7 @@ func genBLengthMap(w *codewriter, rwctx *golang.ReadWriteContext, varname string
vt := t.ValueType

// map header
w.f("off += 6")
w.f("off += 3 + 6")

// iteration tmp var
tmpk := "k"
Expand Down
9 changes: 6 additions & 3 deletions generator/fastgo/gen_fastread.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,18 @@ func (g *FastGoBackend) genFastRead(w *codewriter, scope *golang.Scope, s *golan

if len(ff) > 0 { // fix `label ReadFieldError defined and not used`
w.f("ReadFieldError:")
w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T read field %%d '%%s' error: ", p, fid, fieldIDToName_%s[fid]), err)`, s.GoName())
w.f(`return off, thrift.PrependError(
fmt.Sprintf("%%T read field %%d '%%s' error: ", p, fid, fieldIDToName_%s[fid]), err)`, s.GoName())
}

w.f("SkipFieldError:")
w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T skip field %%d type %%d error: ", p, fid, ftyp), err)`)
w.f(`return off, thrift.PrependError(
fmt.Sprintf("%%T skip field %%d type %%d error: ", p, fid, ftyp), err)`)

if isset.Len() > 0 {
w.f("RequiredFieldNotSetError:")
w.f(`return off, thrift.NewProtocolException(thrift.INVALID_DATA, fmt.Sprintf("required field %%s is not set", fieldIDToName_%s[fid]))`, s.GoName())
w.f(`return off, thrift.NewProtocolException(thrift.INVALID_DATA,
fmt.Sprintf("required field %%s is not set", fieldIDToName_%s[fid]))`, s.GoName())
}

// end of func definition
Expand Down
149 changes: 70 additions & 79 deletions generator/fastgo/gen_fastwrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,39 @@ import (
"github.com/cloudwego/thriftgo/parser"
)

const nocopyWriteThreshold = 4096

func (g *FastGoBackend) genFastWrite(w *codewriter, scope *golang.Scope, s *golang.StructLike) {
w.UsePkg("github.com/cloudwego/gopkg/protocol/thrift", "")
w.f("func (p *%s) FastWrite(b []byte) int { return p.FastWriteNocopy(b, nil) }\n\n", s.GoName())

w.f("func (p *%s) FastWriteNocopy(b []byte, w thrift.NocopyWriter) (n int) {", s.GoName())
w.f(`if n = len(p.FastAppend(b[:0])); n > len(b) {`)
w.f(`panic ("buffer overflow. concurrency issue?")`)
w.f(`}`)
w.f(`return`)
w.f("}\n\n") // end of FastWriteNocopy

g.genFastAppend(w, scope, s)
}

func (g *FastGoBackend) genFastAppend(w *codewriter, scope *golang.Scope, s *golang.StructLike) {
// var conventions:
// - p is the var of pointer to the struct going to be generated
// - b is the buf to write into
// - w is the var of thrift.NocopyWriter
// - off is the offset of b
// - x is the shortcut of thrift.BinaryProtocol

// func definition
w.UsePkg("github.com/cloudwego/gopkg/protocol/thrift", "")
w.f("func (p *%s) FastWrite(b []byte) int { return p.FastWriteNocopy(b, nil) }\n\n", s.GoName())
w.f("func (p *%s) FastWriteNocopy(b []byte, w thrift.NocopyWriter) int {", s.GoName())
w.f("func (p *%s) FastAppend(b []byte) []byte {", s.GoName())
defer w.f("}\n\n")

// case nil, STOP and return
w.f("if p == nil { b[0] = 0; return 1; }")
w.f(`if p == nil { return append(b, 0) }`)

// `off` definition for buf cursor
w.f("off := 0")
// shortcut for encoding
w.f("x := thrift.BinaryProtocol{}")
w.f("_ = x")

// fields
ff := getSortedFields(s)
Expand All @@ -49,24 +65,17 @@ func (g *FastGoBackend) genFastWrite(w *codewriter, scope *golang.Scope, s *gola
// never goes here, should fail early in generator/golang pkg
panic(err)
}
genFastWriteField(w, rwctx, f)
genFastAppendField(w, rwctx, f)
}

// end of field encoding
w.f("") // empty line
w.f("b[off] = 0") // STOP
w.f("return off + 1") // return including the STOP byte

// end of func definition
w.f("}\n\n")
w.f("\nreturn append(b, 0)") // return including the STOP byte
}

func genFastWriteField(w *codewriter, rwctx *golang.ReadWriteContext, f *golang.Field) {
func genFastAppendField(w *codewriter, rwctx *golang.ReadWriteContext, f *golang.Field) {
// the real var name ref to the field
varname := string("p." + f.GoName())

// add comment like // ${FieldName} ${FieldID} ${FieldType}
w.f("\n// %s ID:%d %s", rwctx.Target, f.ID, category2GopkgConsts[f.Type.Category])
w.f("\n// %s", rwctx.Target)

// check skip cases
// only for optional fields
Expand All @@ -83,126 +92,108 @@ func genFastWriteField(w *codewriter, rwctx *golang.ReadWriteContext, f *golang.
}

// field header
w.UsePkg("encoding/binary", "")
w.f("b[off] = %d", category2ThriftWireType[f.Type.Category])
w.f("binary.BigEndian.PutUint16(b[off+1:], %d) ", f.ID)
w.f("off += 3")
w.f("b = append(b, %d, %d, %d)", // AppendFieldBegin
category2ThriftWireType[f.Type.Category], byte(f.ID>>8), byte(f.ID))

// field value
genFastWriteAny(w, rwctx, varname, 0)

genFastAppendAny(w, rwctx, varname, 0)
}

func genFastWriteAny(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) {
func genFastAppendAny(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) {
t := rwctx.Type
pointer := rwctx.IsPointer
switch t.Category {
case parser.Category_Bool:
genFastWriteBool(w, pointer, varname)
genFastAppendBool(w, pointer, varname)
case parser.Category_Byte:
genFastWriteByte(w, pointer, varname)
genFastAppendByte(w, pointer, varname)
case parser.Category_I16:
genFastWriteInt16(w, pointer, varname)
genFastAppendInt16(w, pointer, varname)
case parser.Category_I32, parser.Category_Enum:
genFastWriteInt32(w, pointer, varname)
genFastAppendInt32(w, pointer, varname)
case parser.Category_I64:
genFastWriteInt64(w, pointer, varname)
genFastAppendInt64(w, pointer, varname)
case parser.Category_Double:
genFastWriteDouble(w, pointer, varname)
genFastAppendDouble(w, pointer, varname)
case parser.Category_String:
genFastWriteString(w, pointer, varname)
genFastAppendString(w, pointer, varname)
case parser.Category_Binary:
genFastWriteBinary(w, pointer, varname)
genFastAppendBinary(w, pointer, varname)
case parser.Category_Map:
genFastWriteMap(w, rwctx, varname, depth)
genFastAppendMap(w, rwctx, varname, depth)
case parser.Category_List, parser.Category_Set:
genFastWriteList(w, rwctx, varname, depth)
genFastAppendList(w, rwctx, varname, depth)
case parser.Category_Struct, parser.Category_Union, parser.Category_Exception:
// TODO: fix for parser.Category_Union? must only one field set
genFastWriteStruct(w, rwctx, varname)
genFastAppendStruct(w, rwctx, varname)
}
}

func genFastWriteBool(w *codewriter, pointer bool, varname string) {
func genFastAppendBool(w *codewriter, pointer bool, varname string) {
// for bool, the underlying byte of true is always 1, and 0 for false
// which is same as thrift binary protocol
w.UsePkg("unsafe", "")
w.f("b[off] = *((*byte)(unsafe.Pointer(%s)))", varnamePtr(pointer, varname))
w.f("off++")
w.f("b = append(b, *(*byte)(unsafe.Pointer(%s)))", varnamePtr(pointer, varname))
}

func genFastWriteByte(w *codewriter, pointer bool, varname string) {
w.f("b[off] = byte(%s)", varnameVal(pointer, varname))
w.f("off++")
func genFastAppendByte(w *codewriter, pointer bool, varname string) {
w.f("b = append(b, byte(%s))", varnameVal(pointer, varname))
}

func genFastWriteDouble(w *codewriter, pointer bool, varname string) {
w.UsePkg("unsafe", "")
w.f("binary.BigEndian.PutUint64(b[off:], *(*uint64)(unsafe.Pointer(%s)))", varnamePtr(pointer, varname))
w.f("off += 8")
func genFastAppendDouble(w *codewriter, pointer bool, varname string) {
w.f("b = x.AppendDouble(b, float64(%s))", varnameVal(pointer, varname))
}

func genFastWriteInt16(w *codewriter, pointer bool, varname string) {
w.UsePkg("encoding/binary", "")
w.f("binary.BigEndian.PutUint16(b[off:], uint16(%s))", varnameVal(pointer, varname))
w.f("off += 2")
func genFastAppendInt16(w *codewriter, pointer bool, varname string) {
w.f("b = x.AppendI16(b, int16(%s))", varnameVal(pointer, varname))
}

func genFastWriteInt32(w *codewriter, pointer bool, varname string) {
w.UsePkg("encoding/binary", "")
w.f("binary.BigEndian.PutUint32(b[off:], uint32(%s))", varnameVal(pointer, varname))
w.f("off += 4")
func genFastAppendInt32(w *codewriter, pointer bool, varname string) {
w.f("b = x.AppendI32(b, int32(%s))", varnameVal(pointer, varname))
}

func genFastWriteInt64(w *codewriter, pointer bool, varname string) {
w.UsePkg("encoding/binary", "")
w.f("binary.BigEndian.PutUint64(b[off:], uint64(%s))", varnameVal(pointer, varname))
w.f("off += 8")
func genFastAppendInt64(w *codewriter, pointer bool, varname string) {
w.f("b = x.AppendI64(b, int64(%s))", varnameVal(pointer, varname))
}

func genFastWriteBinary(w *codewriter, pointer bool, varname string) {
func genFastAppendBinary(w *codewriter, pointer bool, varname string) {
varname = varnameVal(pointer, varname)
w.f("off += thrift.Binary.WriteBinaryNocopy(b[off:], w, %s)", varname)
w.f("b = x.AppendI32(b, int32(len(%s)))", varname)
w.f("b = append(b, %s...)", varname)
}

func genFastWriteString(w *codewriter, pointer bool, varname string) {
varname = varnameVal(pointer, varname)
w.f("off += thrift.Binary.WriteStringNocopy(b[off:], w, %s)", varname)
func genFastAppendString(w *codewriter, pointer bool, varname string) {
genFastAppendBinary(w, pointer, varname)
}

func genFastWriteStruct(w *codewriter, rwctx *golang.ReadWriteContext, varname string) {
w.f("off += %s.FastWriteNocopy(b[off:], w)", varname)
func genFastAppendStruct(w *codewriter, rwctx *golang.ReadWriteContext, varname string) {
w.f("b = %s.FastAppend(b)", varname)
}

func genFastWriteList(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) {
func genFastAppendList(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) {
rwctx = rwctx.ValCtx
t := rwctx.Type
w.UsePkg("encoding/binary", "")

// list header
w.f("b[off] = %d", category2ThriftWireType[t.Category])
w.f("binary.BigEndian.PutUint32(b[off+1:], uint32(len(%s)))", varname)
w.f("off += 5")
w.f("b = x.AppendListBegin(b, %s, len(%s))", category2GopkgConsts[t.Category], varname)

// iteration tmp var
tmpv := "v"
if depth > 0 { // avoid redeclared vars
tmpv = "v" + strconv.Itoa(depth-1)
}
w.f("for _, %s := range %s {", tmpv, varname)
genFastWriteAny(w, rwctx, tmpv, depth+1)
genFastAppendAny(w, rwctx, tmpv, depth+1)
w.f("}")
}

func genFastWriteMap(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) {
func genFastAppendMap(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) {
t := rwctx.Type
kt := t.KeyType
vt := t.ValueType
// map header
w.UsePkg("encoding/binary", "")
w.f("b[off] = %d", category2ThriftWireType[kt.Category])
w.f("b[off+1] = %d", category2ThriftWireType[vt.Category])
w.f("binary.BigEndian.PutUint32(b[off+2:], uint32(len(%s)))", varname)
w.f("off += 6")
w.f("b = x.AppendMapBegin(b, %s, %s, len(%s))",
category2GopkgConsts[kt.Category], category2GopkgConsts[vt.Category], varname)

// iteration tmp var
tmpk := "k"
Expand All @@ -212,7 +203,7 @@ func genFastWriteMap(w *codewriter, rwctx *golang.ReadWriteContext, varname stri
tmpv = "v" + strconv.Itoa(depth-1)
}
w.f("for %s, %s := range %s {", tmpk, tmpv, varname)
genFastWriteAny(w, rwctx.KeyCtx, tmpk, depth+1)
genFastWriteAny(w, rwctx.ValCtx, tmpv, depth+1)
genFastAppendAny(w, rwctx.KeyCtx, tmpk, depth+1)
genFastAppendAny(w, rwctx.ValCtx, tmpv, depth+1)
w.f("}")
}
18 changes: 18 additions & 0 deletions test/fastgo/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2024 CloudWeGo Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
.PHONY: all

all:
bash -x ./run.sh
21 changes: 21 additions & 0 deletions test/fastgo/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2024 CloudWeGo Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
set -e
thriftgo -g fastgo:no_default_serdes=true,gen_setter=true -o=. ./testdata.thrift
cd testdata
rm -f go.mod
go mod init thriftgo/test/fastgo/testdata
go mod tidy
go build -v ./...
Loading

0 comments on commit 540d376

Please sign in to comment.