diff --git a/generator/fastgo/codewriter.go b/generator/fastgo/codewriter.go index b307e07..7fdb733 100644 --- a/generator/fastgo/codewriter.go +++ b/generator/fastgo/codewriter.go @@ -19,24 +19,29 @@ package fastgo import ( "bytes" "fmt" + "path" "strings" ) type codewriter struct { *bytes.Buffer - pkgs map[string]struct{} + pkgs map[string]string // import -> alias } func newCodewriter() *codewriter { return &codewriter{ Buffer: &bytes.Buffer{}, - pkgs: make(map[string]struct{}), + pkgs: make(map[string]string), } } -func (w *codewriter) usePkg(s string) { - w.pkgs[s] = struct{}{} +func (w *codewriter) UsePkg(s, a string) { + if path.Base(s) == a { + w.pkgs[s] = "" + } else { + w.pkgs[s] = a + } } func (w *codewriter) Imports() string { @@ -63,7 +68,7 @@ func (w *codewriter) Imports() string { // only imports one pkg? if len(pp0) == 1 { - return fmt.Sprintf("import %q", pp0[0]) + return fmt.Sprintf("import %s %q", w.pkgs[pp0[0]], pp0[0]) } // more than one imports @@ -73,7 +78,7 @@ func (w *codewriter) Imports() string { if p == "" { fmt.Fprintln(s, "") } else { - fmt.Fprintf(s, " %q\n", p) + fmt.Fprintf(s, "%s %q\n", w.pkgs[p], p) } } fmt.Fprintln(s, ")") diff --git a/generator/fastgo/fastgo.go b/generator/fastgo/fastgo.go index 5aa17e7..bd04ff9 100644 --- a/generator/fastgo/fastgo.go +++ b/generator/fastgo/fastgo.go @@ -55,7 +55,7 @@ func (g *FastGoBackend) Generate(req *plugin.Request, log backend.LogFunc) *plug } g.req = req g.log = log - g.utils = golang.NewCodeUtils(log) + g.utils = g.GoBackend.GetCoreUtils() var trees chan *parser.Thrift if req.Recursive { trees = req.AST.DepthFirstSearch() @@ -105,10 +105,20 @@ func (g *FastGoBackend) GenerateOne(ast *parser.Thrift) (*plugin.Generated, erro g.generateStruct(w, scope, s) } for _, s := range scope.Unions() { - _ = s + g.generateStruct(w, scope, s) } for _, s := range scope.Exceptions() { - _ = s + g.generateStruct(w, scope, s) + } + for _, ss := range scope.Services() { + for _, f := range ss.Functions() { + if s := f.ArgType(); s != nil { + g.generateStruct(w, scope, s) + } + if s := f.ResType(); s != nil { + g.generateStruct(w, scope, s) + } + } } ret := &plugin.Generated{} @@ -123,11 +133,31 @@ func (g *FastGoBackend) GenerateOne(ast *parser.Thrift) (*plugin.Generated, erro fmt.Fprintf(c, "%s\npackage %s\n\n", fixedFileHeader, packageName) // Imports + unusedProtect := false + for _, incl := range scope.Includes() { + if incl == nil { // TODO(liyun.339): fix this + continue + } + unusedProtect = true + w.UsePkg(incl.ImportPath, incl.PackageName) + } if len(w.pkgs) > 0 { c.WriteString(w.Imports()) } c.WriteByte('\n') + // Unused protects + if unusedProtect { + fmt.Fprintln(c, "var (") + for _, incl := range scope.Includes() { + if incl == nil { // TODO(liyun.339): fix this + continue + } + fmt.Fprintf(c, "_ = %s.KitexUnusedProtection\n", incl.PackageName) + } + fmt.Fprintln(c, ")") + } + // Methods c.Write(w.Bytes()) diff --git a/generator/fastgo/gen_blength.go b/generator/fastgo/gen_blength.go index d7dab46..19b9c3e 100644 --- a/generator/fastgo/gen_blength.go +++ b/generator/fastgo/gen_blength.go @@ -31,7 +31,7 @@ func (g *FastGoBackend) genBLength(w *codewriter, scope *golang.Scope, s *golang // - off is the counter of BLength // func definition - w.usePkg("github.com/cloudwego/gopkg/protocol/thrift") + w.UsePkg("github.com/cloudwego/gopkg/protocol/thrift", "") w.f("func (p *%s) BLength() int {", s.GoName()) // case nil, STOP @@ -40,7 +40,8 @@ func (g *FastGoBackend) genBLength(w *codewriter, scope *golang.Scope, s *golang w.f("off := 0") // fields - for _, f := range s.Fields() { + ff := getSortedFields(s) + for _, f := range ff { rwctx, err := g.utils.MkRWCtx(scope, f) if err != nil { // never goes here, should fail early in generator/golang pkg @@ -167,7 +168,7 @@ func genBLengthMap(w *codewriter, rwctx *golang.ReadWriteContext, varname string genBLengthAny(w, rwctx.ValCtx, tmpv, depth+1) w.f("}") } else if vsz > 0 { - w.f("off += %s * %d", varname, vsz) + w.f("off += len(%s) * %d", varname, vsz) w.f("for %s, _ := range %s {", tmpk, varname) genBLengthAny(w, rwctx.KeyCtx, tmpk, depth+1) w.f("}") diff --git a/generator/fastgo/gen_fastread.go b/generator/fastgo/gen_fastread.go index 5bf3fda..7da4e1a 100644 --- a/generator/fastgo/gen_fastread.go +++ b/generator/fastgo/gen_fastread.go @@ -38,15 +38,15 @@ func (g *FastGoBackend) genFastRead(w *codewriter, scope *golang.Scope, s *golan // Instead of using consts for vars above, would like to use the names directly making code clear // func definition - w.usePkg("github.com/cloudwego/gopkg/protocol/thrift") + w.UsePkg("github.com/cloudwego/gopkg/protocol/thrift", "") w.f("func (p *%s) FastRead(b []byte) (off int, err error) {", s.GoName()) w.f("var ftyp thrift.TType") w.f("var fid int16") w.f("var l int") isset := newBitsetCodeGen("isset", "uint8") - ff := getSortedFields(s) hasEnum := false + ff := getSortedFields(s) for _, f := range ff { if f.Type.Category == parser.Category_Enum { hasEnum = true @@ -100,12 +100,14 @@ func (g *FastGoBackend) genFastRead(w *codewriter, scope *golang.Scope, s *golan w.f("return") // no error - w.usePkg("fmt") + w.UsePkg("fmt", "") w.f("ReadFieldBeginError:") w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T read field begin error: ", p), err)`) - w.f("ReadFieldError:") - w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T read field %%d '%%s' error: ", p, fid, fieldIDToName_%s[fid]), err)`, s.GoName()) + if len(ff) > 0 { // fix `label ReadFieldError defined and not used` + w.f("ReadFieldError:") + w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T read field %%d '%%s' error: ", p, fid, fieldIDToName_%s[fid]), err)`, s.GoName()) + } w.f("SkipFieldError:") w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T skip field %%d type %%d error: ", p, fid, ftyp), err)`) @@ -311,7 +313,13 @@ func genFastReadMap(w *codewriter, rwctx *golang.ReadWriteContext, varname strin w.f("%s = make(%s, %s)", varname, rwctx.TypeName, tmpsize) w.f("for %s := 0; %s < %s; %s++ {", tmpi, tmpi, tmpsize, tmpi) - w.f("var %s %s", tmpk, rwctx.KeyCtx.TypeName) + if rwctx.KeyCtx.TypeID == "Struct" && !rwctx.KeyCtx.IsPointer { + // hotfix for struct, it's always pointer for keys + // remove this check after generator/gopkg fix it + w.f("var %s *%s", tmpk, rwctx.KeyCtx.TypeName) + } else { + w.f("var %s %s", tmpk, rwctx.KeyCtx.TypeName) + } w.f("var %s %s", tmpv, rwctx.ValCtx.TypeName) genFastReadAny(w, rwctx.KeyCtx, tmpk, depth+1) genFastReadAny(w, rwctx.ValCtx, tmpv, depth+1) diff --git a/generator/fastgo/gen_fastwrite.go b/generator/fastgo/gen_fastwrite.go index 25911a6..246d49f 100644 --- a/generator/fastgo/gen_fastwrite.go +++ b/generator/fastgo/gen_fastwrite.go @@ -31,7 +31,7 @@ func (g *FastGoBackend) genFastWrite(w *codewriter, scope *golang.Scope, s *gola // - off is the offset of b // func definition - w.usePkg("github.com/cloudwego/gopkg/protocol/thrift") + w.UsePkg("github.com/cloudwego/gopkg/protocol/thrift", "") w.f("func (p *%s) FastWrite(b []byte) int { return p.FastWriteNocopy(b, nil) }\n\n", s.GoName()) w.f("func (p *%s) FastWriteNocopy(b []byte, w thrift.NocopyWriter) int {", s.GoName()) @@ -42,7 +42,8 @@ func (g *FastGoBackend) genFastWrite(w *codewriter, scope *golang.Scope, s *gola w.f("off := 0") // fields - for _, f := range s.Fields() { + ff := getSortedFields(s) + for _, f := range ff { rwctx, err := g.utils.MkRWCtx(scope, f) if err != nil { // never goes here, should fail early in generator/golang pkg @@ -82,7 +83,7 @@ func genFastWriteField(w *codewriter, rwctx *golang.ReadWriteContext, f *golang. } // field header - w.usePkg("encoding/binary") + w.UsePkg("encoding/binary", "") w.f("b[off] = %d", category2ThriftWireType[f.Type.Category]) w.f("binary.BigEndian.PutUint16(b[off+1:], %d) ", f.ID) w.f("off += 3") @@ -125,7 +126,7 @@ func genFastWriteAny(w *codewriter, rwctx *golang.ReadWriteContext, varname stri func genFastWriteBool(w *codewriter, pointer bool, varname string) { // for bool, the underlying byte of true is always 1, and 0 for false // which is same as thrift binary protocol - w.usePkg("unsafe") + w.UsePkg("unsafe", "") w.f("b[off] = *((*byte)(unsafe.Pointer(%s)))", varnamePtr(pointer, varname)) w.f("off++") } @@ -136,25 +137,25 @@ func genFastWriteByte(w *codewriter, pointer bool, varname string) { } func genFastWriteDouble(w *codewriter, pointer bool, varname string) { - w.usePkg("unsafe") + w.UsePkg("unsafe", "") w.f("binary.BigEndian.PutUint64(b[off:], *(*uint64)(unsafe.Pointer(%s)))", varnamePtr(pointer, varname)) w.f("off += 8") } func genFastWriteInt16(w *codewriter, pointer bool, varname string) { - w.usePkg("encoding/binary") + w.UsePkg("encoding/binary", "") w.f("binary.BigEndian.PutUint16(b[off:], uint16(%s))", varnameVal(pointer, varname)) w.f("off += 2") } func genFastWriteInt32(w *codewriter, pointer bool, varname string) { - w.usePkg("encoding/binary") + w.UsePkg("encoding/binary", "") w.f("binary.BigEndian.PutUint32(b[off:], uint32(%s))", varnameVal(pointer, varname)) w.f("off += 4") } func genFastWriteInt64(w *codewriter, pointer bool, varname string) { - w.usePkg("encoding/binary") + w.UsePkg("encoding/binary", "") w.f("binary.BigEndian.PutUint64(b[off:], uint64(%s))", varnameVal(pointer, varname)) w.f("off += 8") } @@ -176,7 +177,7 @@ func genFastWriteStruct(w *codewriter, rwctx *golang.ReadWriteContext, varname s func genFastWriteList(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) { rwctx = rwctx.ValCtx t := rwctx.Type - w.usePkg("encoding/binary") + w.UsePkg("encoding/binary", "") // list header w.f("b[off] = %d", category2ThriftWireType[t.Category]) w.f("binary.BigEndian.PutUint32(b[off+1:], uint32(len(%s)))", varname) @@ -197,7 +198,7 @@ func genFastWriteMap(w *codewriter, rwctx *golang.ReadWriteContext, varname stri kt := t.KeyType vt := t.ValueType // map header - w.usePkg("encoding/binary") + w.UsePkg("encoding/binary", "") w.f("b[off] = %d", category2ThriftWireType[kt.Category]) w.f("b[off+1] = %d", category2ThriftWireType[vt.Category]) w.f("binary.BigEndian.PutUint32(b[off+2:], uint32(len(%s)))", varname) diff --git a/generator/golang/backend.go b/generator/golang/backend.go index caa29d2..b949f30 100644 --- a/generator/golang/backend.go +++ b/generator/golang/backend.go @@ -107,6 +107,10 @@ func (g *GoBackend) Generate(req *plugin.Request, log backend.LogFunc) *plugin.R return g.buildResponse() } +func (g *GoBackend) GetCoreUtils() *CodeUtils { + return g.utils +} + func (g *GoBackend) prepareUtilities() { if g.err != nil { return