Skip to content

Commit

Permalink
fix: generating code in dynamicgo repo
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaost committed Jul 31, 2024
1 parent aada66a commit 88ecb45
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 28 deletions.
17 changes: 11 additions & 6 deletions generator/fastgo/codewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,29 @@ package fastgo
import (
"bytes"
"fmt"
"path"
"strings"
)

type codewriter struct {
*bytes.Buffer

pkgs map[string]struct{}
pkgs map[string]string // import -> alias
}

func newCodewriter() *codewriter {
return &codewriter{
Buffer: &bytes.Buffer{},
pkgs: make(map[string]struct{}),
pkgs: make(map[string]string),
}
}

func (w *codewriter) usePkg(s string) {
w.pkgs[s] = struct{}{}
func (w *codewriter) UsePkg(s, a string) {
if path.Base(s) == a {
w.pkgs[s] = ""
} else {
w.pkgs[s] = a
}
}

func (w *codewriter) Imports() string {
Expand All @@ -63,7 +68,7 @@ func (w *codewriter) Imports() string {

// only imports one pkg?
if len(pp0) == 1 {
return fmt.Sprintf("import %q", pp0[0])
return fmt.Sprintf("import %s %q", w.pkgs[pp0[0]], pp0[0])
}

// more than one imports
Expand All @@ -73,7 +78,7 @@ func (w *codewriter) Imports() string {
if p == "" {
fmt.Fprintln(s, "")
} else {
fmt.Fprintf(s, " %q\n", p)
fmt.Fprintf(s, "%s %q\n", w.pkgs[p], p)
}
}
fmt.Fprintln(s, ")")
Expand Down
36 changes: 33 additions & 3 deletions generator/fastgo/fastgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (g *FastGoBackend) Generate(req *plugin.Request, log backend.LogFunc) *plug
}
g.req = req
g.log = log
g.utils = golang.NewCodeUtils(log)
g.utils = g.GoBackend.GetCoreUtils()
var trees chan *parser.Thrift
if req.Recursive {
trees = req.AST.DepthFirstSearch()
Expand Down Expand Up @@ -105,10 +105,20 @@ func (g *FastGoBackend) GenerateOne(ast *parser.Thrift) (*plugin.Generated, erro
g.generateStruct(w, scope, s)
}
for _, s := range scope.Unions() {
_ = s
g.generateStruct(w, scope, s)
}
for _, s := range scope.Exceptions() {
_ = s
g.generateStruct(w, scope, s)
}
for _, ss := range scope.Services() {
for _, f := range ss.Functions() {
if s := f.ArgType(); s != nil {
g.generateStruct(w, scope, s)
}
if s := f.ResType(); s != nil {
g.generateStruct(w, scope, s)
}
}
}

ret := &plugin.Generated{}
Expand All @@ -123,11 +133,31 @@ func (g *FastGoBackend) GenerateOne(ast *parser.Thrift) (*plugin.Generated, erro
fmt.Fprintf(c, "%s\npackage %s\n\n", fixedFileHeader, packageName)

// Imports
unusedProtect := false
for _, incl := range scope.Includes() {
if incl == nil { // TODO(liyun.339): fix this
continue
}
unusedProtect = true
w.UsePkg(incl.ImportPath, incl.PackageName)
}
if len(w.pkgs) > 0 {
c.WriteString(w.Imports())
}
c.WriteByte('\n')

// Unused protects
if unusedProtect {
fmt.Fprintln(c, "var (")
for _, incl := range scope.Includes() {
if incl == nil { // TODO(liyun.339): fix this
continue
}
fmt.Fprintf(c, "_ = %s.KitexUnusedProtection\n", incl.PackageName)
}
fmt.Fprintln(c, ")")
}

// Methods
c.Write(w.Bytes())

Expand Down
7 changes: 4 additions & 3 deletions generator/fastgo/gen_blength.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (g *FastGoBackend) genBLength(w *codewriter, scope *golang.Scope, s *golang
// - off is the counter of BLength

// func definition
w.usePkg("github.com/cloudwego/gopkg/protocol/thrift")
w.UsePkg("github.com/cloudwego/gopkg/protocol/thrift", "")
w.f("func (p *%s) BLength() int {", s.GoName())

// case nil, STOP
Expand All @@ -40,7 +40,8 @@ func (g *FastGoBackend) genBLength(w *codewriter, scope *golang.Scope, s *golang
w.f("off := 0")

// fields
for _, f := range s.Fields() {
ff := getSortedFields(s)
for _, f := range ff {
rwctx, err := g.utils.MkRWCtx(scope, f)
if err != nil {
// never goes here, should fail early in generator/golang pkg
Expand Down Expand Up @@ -167,7 +168,7 @@ func genBLengthMap(w *codewriter, rwctx *golang.ReadWriteContext, varname string
genBLengthAny(w, rwctx.ValCtx, tmpv, depth+1)
w.f("}")
} else if vsz > 0 {
w.f("off += %s * %d", varname, vsz)
w.f("off += len(%s) * %d", varname, vsz)
w.f("for %s, _ := range %s {", tmpk, varname)
genBLengthAny(w, rwctx.KeyCtx, tmpk, depth+1)
w.f("}")
Expand Down
20 changes: 14 additions & 6 deletions generator/fastgo/gen_fastread.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ func (g *FastGoBackend) genFastRead(w *codewriter, scope *golang.Scope, s *golan
// Instead of using consts for vars above, would like to use the names directly making code clear

// func definition
w.usePkg("github.com/cloudwego/gopkg/protocol/thrift")
w.UsePkg("github.com/cloudwego/gopkg/protocol/thrift", "")
w.f("func (p *%s) FastRead(b []byte) (off int, err error) {", s.GoName())
w.f("var ftyp thrift.TType")
w.f("var fid int16")
w.f("var l int")

isset := newBitsetCodeGen("isset", "uint8")
ff := getSortedFields(s)
hasEnum := false
ff := getSortedFields(s)
for _, f := range ff {
if f.Type.Category == parser.Category_Enum {
hasEnum = true
Expand Down Expand Up @@ -100,12 +100,14 @@ func (g *FastGoBackend) genFastRead(w *codewriter, scope *golang.Scope, s *golan

w.f("return") // no error

w.usePkg("fmt")
w.UsePkg("fmt", "")
w.f("ReadFieldBeginError:")
w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T read field begin error: ", p), err)`)

w.f("ReadFieldError:")
w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T read field %%d '%%s' error: ", p, fid, fieldIDToName_%s[fid]), err)`, s.GoName())
if len(ff) > 0 { // fix `label ReadFieldError defined and not used`
w.f("ReadFieldError:")
w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T read field %%d '%%s' error: ", p, fid, fieldIDToName_%s[fid]), err)`, s.GoName())
}

w.f("SkipFieldError:")
w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T skip field %%d type %%d error: ", p, fid, ftyp), err)`)
Expand Down Expand Up @@ -311,7 +313,13 @@ func genFastReadMap(w *codewriter, rwctx *golang.ReadWriteContext, varname strin

w.f("%s = make(%s, %s)", varname, rwctx.TypeName, tmpsize)
w.f("for %s := 0; %s < %s; %s++ {", tmpi, tmpi, tmpsize, tmpi)
w.f("var %s %s", tmpk, rwctx.KeyCtx.TypeName)
if rwctx.KeyCtx.TypeID == "Struct" && !rwctx.KeyCtx.IsPointer {
// hotfix for struct, it's always pointer for keys
// remove this check after generator/gopkg fix it
w.f("var %s *%s", tmpk, rwctx.KeyCtx.TypeName)
} else {
w.f("var %s %s", tmpk, rwctx.KeyCtx.TypeName)
}
w.f("var %s %s", tmpv, rwctx.ValCtx.TypeName)
genFastReadAny(w, rwctx.KeyCtx, tmpk, depth+1)
genFastReadAny(w, rwctx.ValCtx, tmpv, depth+1)
Expand Down
21 changes: 11 additions & 10 deletions generator/fastgo/gen_fastwrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (g *FastGoBackend) genFastWrite(w *codewriter, scope *golang.Scope, s *gola
// - off is the offset of b

// func definition
w.usePkg("github.com/cloudwego/gopkg/protocol/thrift")
w.UsePkg("github.com/cloudwego/gopkg/protocol/thrift", "")
w.f("func (p *%s) FastWrite(b []byte) int { return p.FastWriteNocopy(b, nil) }\n\n", s.GoName())
w.f("func (p *%s) FastWriteNocopy(b []byte, w thrift.NocopyWriter) int {", s.GoName())

Expand All @@ -42,7 +42,8 @@ func (g *FastGoBackend) genFastWrite(w *codewriter, scope *golang.Scope, s *gola
w.f("off := 0")

// fields
for _, f := range s.Fields() {
ff := getSortedFields(s)
for _, f := range ff {
rwctx, err := g.utils.MkRWCtx(scope, f)
if err != nil {
// never goes here, should fail early in generator/golang pkg
Expand Down Expand Up @@ -82,7 +83,7 @@ func genFastWriteField(w *codewriter, rwctx *golang.ReadWriteContext, f *golang.
}

// field header
w.usePkg("encoding/binary")
w.UsePkg("encoding/binary", "")
w.f("b[off] = %d", category2ThriftWireType[f.Type.Category])
w.f("binary.BigEndian.PutUint16(b[off+1:], %d) ", f.ID)
w.f("off += 3")
Expand Down Expand Up @@ -125,7 +126,7 @@ func genFastWriteAny(w *codewriter, rwctx *golang.ReadWriteContext, varname stri
func genFastWriteBool(w *codewriter, pointer bool, varname string) {
// for bool, the underlying byte of true is always 1, and 0 for false
// which is same as thrift binary protocol
w.usePkg("unsafe")
w.UsePkg("unsafe", "")
w.f("b[off] = *((*byte)(unsafe.Pointer(%s)))", varnamePtr(pointer, varname))
w.f("off++")
}
Expand All @@ -136,25 +137,25 @@ func genFastWriteByte(w *codewriter, pointer bool, varname string) {
}

func genFastWriteDouble(w *codewriter, pointer bool, varname string) {
w.usePkg("unsafe")
w.UsePkg("unsafe", "")
w.f("binary.BigEndian.PutUint64(b[off:], *(*uint64)(unsafe.Pointer(%s)))", varnamePtr(pointer, varname))
w.f("off += 8")
}

func genFastWriteInt16(w *codewriter, pointer bool, varname string) {
w.usePkg("encoding/binary")
w.UsePkg("encoding/binary", "")
w.f("binary.BigEndian.PutUint16(b[off:], uint16(%s))", varnameVal(pointer, varname))
w.f("off += 2")
}

func genFastWriteInt32(w *codewriter, pointer bool, varname string) {
w.usePkg("encoding/binary")
w.UsePkg("encoding/binary", "")
w.f("binary.BigEndian.PutUint32(b[off:], uint32(%s))", varnameVal(pointer, varname))
w.f("off += 4")
}

func genFastWriteInt64(w *codewriter, pointer bool, varname string) {
w.usePkg("encoding/binary")
w.UsePkg("encoding/binary", "")
w.f("binary.BigEndian.PutUint64(b[off:], uint64(%s))", varnameVal(pointer, varname))
w.f("off += 8")
}
Expand All @@ -176,7 +177,7 @@ func genFastWriteStruct(w *codewriter, rwctx *golang.ReadWriteContext, varname s
func genFastWriteList(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) {
rwctx = rwctx.ValCtx
t := rwctx.Type
w.usePkg("encoding/binary")
w.UsePkg("encoding/binary", "")
// list header
w.f("b[off] = %d", category2ThriftWireType[t.Category])
w.f("binary.BigEndian.PutUint32(b[off+1:], uint32(len(%s)))", varname)
Expand All @@ -197,7 +198,7 @@ func genFastWriteMap(w *codewriter, rwctx *golang.ReadWriteContext, varname stri
kt := t.KeyType
vt := t.ValueType
// map header
w.usePkg("encoding/binary")
w.UsePkg("encoding/binary", "")
w.f("b[off] = %d", category2ThriftWireType[kt.Category])
w.f("b[off+1] = %d", category2ThriftWireType[vt.Category])
w.f("binary.BigEndian.PutUint32(b[off+2:], uint32(len(%s)))", varname)
Expand Down
4 changes: 4 additions & 0 deletions generator/golang/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ func (g *GoBackend) Generate(req *plugin.Request, log backend.LogFunc) *plugin.R
return g.buildResponse()
}

func (g *GoBackend) GetCoreUtils() *CodeUtils {
return g.utils
}

func (g *GoBackend) prepareUtilities() {
if g.err != nil {
return
Expand Down

0 comments on commit 88ecb45

Please sign in to comment.