gr_hz/thrift/ast.go
2024-04-30 23:07:06 +08:00

917 lines
26 KiB
Go

/*
* Copyright 2022 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package thrift
import (
"fmt"
"sort"
"strings"
"github.com/cloudwego/thriftgo/generator/golang"
"github.com/cloudwego/thriftgo/generator/golang/styles"
"github.com/cloudwego/thriftgo/parser"
"github.com/cloudwego/thriftgo/semantic"
"golib.gaore.com/GaoreGo/gr_hz/config"
"golib.gaore.com/GaoreGo/gr_hz/generator"
"golib.gaore.com/GaoreGo/gr_hz/generator/model"
"golib.gaore.com/GaoreGo/gr_hz/meta"
"golib.gaore.com/GaoreGo/gr_hz/util"
"golib.gaore.com/GaoreGo/gr_hz/util/logs"
)
/*---------------------------Import-----------------------------*/
// getGoPackage determines the Go package for the given thrift AST.
//
// Resolution order:
//  1. an explicit entry in pkgMap keyed by the IDL file name;
//  2. the "go" namespace (or reference name) reported by the AST;
//  3. the IDL file name, handed to util.SplitPackage with ".thrift" as the
//     suffix to strip.
func getGoPackage(ast *parser.Thrift, pkgMap map[string]string) string {
	idlFile := ast.GetFilename()
	if mapped, found := pkgMap[idlFile]; found {
		return mapped
	}

	if namespace := ast.GetNamespaceOrReferenceName("go"); namespace != "" {
		return util.SplitPackage(namespace, "")
	}

	// No namespace declared: fall back to the file name (without the
	// extension) as the package name.
	return util.SplitPackage(idlFile, ".thrift")
}
/*---------------------------Service-----------------------------*/
// astToService converts every service declared in ast into a
// generator.Service, one HttpMethod per HTTP route annotation found on the
// service's functions.
//
// Services that are extended by another service are skipped when
// args.EnableExtends is set; the functions of the extended services are merged
// into the extending service instead. Functions without any HTTP method
// annotation are ignored. When args.CmdType is meta.CmdClient, a ClientMethod
// is additionally produced for each function from the annotations of its first
// argument.
//
// An error is returned when an annotation is malformed (empty route, more than
// one generation path), when a request/response type cannot be resolved, when
// two referenced models share a package name but not a package, or, in client
// mode, when a function has no request argument.
func astToService(ast *parser.Thrift, resolver *Resolver, args *config.Argument) ([]*generator.Service, error) {
	ss := ast.GetServices()
	out := make([]*generator.Service, 0, len(ss))
	// models accumulates across all services of this AST.
	var models model.Models
	extendServices := getExtendServices(ast)
	for _, s := range ss {
		// if the service is extended, it is not processed
		if extendServices.exist(s.Name) && args.EnableExtends {
			logs.Debugf("%s is extended, so skip it\n", s.Name)
			continue
		}
		resolver.ExportReferred(true, false)
		service := &generator.Service{
			Name: s.GetName(),
		}

		// The base domain is only applied when generating a client.
		domainAnno := getAnnotation(s.Annotations, ApiBaseDomain)
		if len(domainAnno) == 1 && args.CmdType == meta.CmdClient {
			service.BaseDomain = domainAnno[0]
		}

		// The service group and generation directory are only applied when
		// NOT generating a client.
		groupAnno := getAnnotation(s.Annotations, ApiServiceGroup)
		if len(groupAnno) == 1 && args.CmdType != meta.CmdClient {
			service.ServiceGroup = groupAnno[0]
		}
		serviceGenDirAnno := getAnnotation(s.Annotations, ApiServiceGenDir)
		if len(serviceGenDirAnno) == 1 && args.CmdType != meta.CmdClient {
			service.ServiceGenDir = serviceGenDirAnno[0]
		}

		ms := s.GetFunctions()
		if len(s.Extends) != 0 && args.EnableExtends {
			// all the services that are extended to the current service
			extendsFuncs, err := getAllExtendFunction(s, ast, resolver, args)
			if err != nil {
				return nil, fmt.Errorf("parser extend function failed, err=%w", err)
			}
			ms = append(ms, extendsFuncs...)
		}

		methods := make([]*generator.HttpMethod, 0, len(ms))
		clientMethods := make([]*generator.ClientMethod, 0, len(ms))

		// The service-level path is the default handler output directory.
		servicePathAnno := getAnnotation(s.Annotations, ApiServicePath)
		servicePath := ""
		if len(servicePathAnno) > 0 {
			servicePath = servicePathAnno[0]
		}

		for _, m := range ms {
			rs := getAnnotations(m.Annotations, HttpMethodAnnotations)
			if len(rs) == 0 {
				continue
			}
			httpAnnos := httpAnnotations{}
			for k, v := range rs {
				httpAnnos = append(httpAnnos, httpAnnotation{
					method: k,
					path:   v,
				})
			}
			// turn the map into a slice and sort it to make sure getting the results in the same order every time
			sort.Sort(httpAnnos)

			// A function-level generation path overrides the service-level one.
			handlerOutDir := servicePath
			genPaths := getAnnotation(m.Annotations, ApiGenPath)
			if len(genPaths) == 1 {
				handlerOutDir = genPaths[0]
			} else if len(genPaths) > 0 {
				return nil, fmt.Errorf("too many 'api.handler_path' for %s", m.Name)
			}

			hmethod, path := httpAnnos[0].method, httpAnnos[0].path
			if len(path) == 0 || path[0] == "" {
				return nil, fmt.Errorf("invalid api.%s for %s.%s: %s", hmethod, s.Name, m.Name, path)
			}

			var reqName, reqRawName, reqPackage string
			if len(m.Arguments) >= 1 {
				if len(m.Arguments) > 1 {
					logs.Warnf("function '%s' has more than one argument, but only the first can be used in hertz now", m.GetName())
				}
				var err error
				reqName, err = resolver.ResolveTypeName(m.Arguments[0].GetType())
				if err != nil {
					return nil, err
				}
				if strings.Contains(reqName, ".") && !m.Arguments[0].GetType().Category.IsContainerType() {
					// If reqName contains "." , then it must be of the form "pkg.name".
					// so reqRawName='name', reqPackage='pkg'
					names := strings.Split(reqName, ".")
					if len(names) != 2 {
						return nil, fmt.Errorf("request name: %s is wrong", reqName)
					}
					reqRawName = names[1]
					reqPackage = names[0]
				}
			}

			var respName, respRawName, respPackage string
			if !m.Oneway {
				var err error
				respName, err = resolver.ResolveTypeName(m.GetFunctionType())
				if err != nil {
					return nil, err
				}
				if strings.Contains(respName, ".") && !m.GetFunctionType().Category.IsContainerType() {
					names := strings.Split(respName, ".")
					if len(names) != 2 {
						return nil, fmt.Errorf("response name: %s is wrong", respName)
					}
					// If respName contains "." , then it must be of the form "pkg.name".
					// so respRawName='name', respPackage='pkg'
					respRawName = names[1]
					respPackage = names[0]
				}
			}

			sr, _ := util.GetFirstKV(getAnnotations(m.Annotations, SerializerTags))
			method := &generator.HttpMethod{
				Name:               util.CamelString(m.GetName()),
				HTTPMethod:         hmethod,
				RequestTypeName:    reqName,
				RequestTypeRawName: reqRawName,
				RequestTypePackage: reqPackage,
				ReturnTypeName:     respName,
				ReturnTypeRawName:  respRawName,
				ReturnTypePackage:  respPackage,
				Path:               path[0],
				Serializer:         sr,
				OutputDir:          handlerOutDir,
				GenHandler:         true,
				// Annotations: m.Annotations,
			}

			// Collect the models referenced while resolving this method's types.
			refs := resolver.ExportReferred(false, true)
			method.Models = make(map[string]*model.Model, len(refs))
			for _, ref := range refs {
				if v, ok := method.Models[ref.Model.PackageName]; ok && (v.Package != ref.Model.Package) {
					return nil, fmt.Errorf("Package name: %s redeclared in %s and %s ", ref.Model.PackageName, v.Package, ref.Model.Package)
				}
				method.Models[ref.Model.PackageName] = ref.Model
			}
			models.MergeMap(method.Models)
			methods = append(methods, method)

			// Every additional route of the same function becomes a copy of the
			// primary method that does not generate a handler of its own.
			for idx, anno := range httpAnnos {
				for i := 0; i < len(anno.path); i++ {
					if idx == 0 && i == 0 { // idx==0 && i==0 has been added above
						continue
					}
					newMethod, err := newHTTPMethod(s, m, method, i, anno)
					if err != nil {
						return nil, err
					}
					methods = append(methods, newMethod)
				}
			}

			if args.CmdType == meta.CmdClient {
				// The client code is derived from the first argument; guard the
				// index so an argument-less function yields an error instead of
				// an index-out-of-range panic.
				if len(m.Arguments) == 0 {
					return nil, fmt.Errorf("function '%s.%s' has no request argument, which is required to generate client code", s.Name, m.Name)
				}
				clientMethod := &generator.ClientMethod{}
				clientMethod.HttpMethod = method
				rt, err := resolver.ResolveIdentifier(m.Arguments[0].GetType().GetName())
				if err != nil {
					return nil, err
				}
				err = parseAnnotationToClient(clientMethod, m.Arguments[0].GetType(), rt)
				if err != nil {
					return nil, err
				}
				clientMethods = append(clientMethods, clientMethod)
			}
		}

		service.ClientMethods = clientMethods
		service.Methods = methods
		service.Models = models
		out = append(out, service)
	}
	return out, nil
}
// newHTTPMethod derives an additional route from an already built method.
//
// The result is a shallow copy of method whose HTTP verb and path are taken
// from anno (path index i) and whose GenHandler flag is cleared. An error is
// returned when the selected path is empty; s and m are only used to build
// that error message.
func newHTTPMethod(s *parser.Service, m *parser.Function, method *generator.HttpMethod, i int, anno httpAnnotation) (*generator.HttpMethod, error) {
	derived := *method

	route := anno.path[i]
	if route == "" {
		return nil, fmt.Errorf("invalid api.%s for %s.%s: %s", anno.method, s.Name, m.Name, route)
	}

	derived.HTTPMethod = anno.method
	derived.Path = route
	derived.GenHandler = false
	return &derived, nil
}
// parseAnnotationToClient fills the code snippets of clientMethod (query, path,
// header, form, file and body parameters) from the field annotations of the
// request struct described by p.
//
// p is the request type and symbol the resolved symbol whose scope is used to
// look the struct up. When the type is not struct-like nothing is generated and
// nil is returned. For fields without any recognised annotation, a query
// parameter named after the field is emitted when the HTTP method is GET.
//
// Body and form handling: the body snippet defaults to meta.SetBodyParam. If
// both body and form annotations are present, the form snippets are dropped;
// if only form annotations are present, the body snippet is dropped.
//
// An error is returned when p is nil or the scope cannot be built.
func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Type, symbol ResolvedSymbol) error {
	if p == nil {
		return fmt.Errorf("get type failed for parse annotatoon to client")
	}

	// Only the last segment of a qualified name ("pkg.Name") is looked up.
	typeName := p.GetName()
	if strings.Contains(typeName, ".") {
		ret := strings.Split(typeName, ".")
		typeName = ret[len(ret)-1]
	}

	scope, err := golang.BuildScope(thriftgoUtil, symbol.Scope)
	if err != nil {
		// Keep the underlying cause instead of discarding it.
		return fmt.Errorf("can not build scope for %s: %w", p.Name, err)
	}
	thriftgoUtil.SetRootScope(scope)

	st := scope.StructLike(typeName)
	if st == nil {
		// The format string has two verbs; supply both arguments.
		logs.Infof("the type '%s' for method '%s' is base type, so skip parse client info\n", typeName, clientMethod.Name)
		return nil
	}

	var (
		hasBodyAnnotation bool
		hasFormAnnotation bool
	)
	for _, field := range st.Fields() {
		hasAnnotation := false
		isStringFieldType := field.GetType().String() == "string"

		// plain emits `"key": req.GetX(),`.
		plain := func(key string) string {
			return fmt.Sprintf("%q: req.Get%s(),\n", key, field.GoName().String())
		}
		// stringified wraps the getter of non-string fields in fmt.Sprint.
		stringified := func(key string) string {
			if isStringFieldType {
				return plain(key)
			}
			return fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", key, field.GoName().String())
		}

		if anno := getAnnotation(field.Annotations, AnnotationQuery); len(anno) > 0 {
			hasAnnotation = true
			clientMethod.QueryParamsCode += plain(checkSnakeName(anno[0]))
		}
		if anno := getAnnotation(field.Annotations, AnnotationPath); len(anno) > 0 {
			hasAnnotation = true
			clientMethod.PathParamsCode += stringified(checkSnakeName(anno[0]))
		}
		if anno := getAnnotation(field.Annotations, AnnotationHeader); len(anno) > 0 {
			hasAnnotation = true
			clientMethod.HeaderParamsCode += stringified(checkSnakeName(anno[0]))
		}
		if anno := getAnnotation(field.Annotations, AnnotationForm); len(anno) > 0 {
			hasAnnotation = true
			hasFormAnnotation = true
			clientMethod.FormValueCode += stringified(checkSnakeName(anno[0]))
		}
		if anno := getAnnotation(field.Annotations, AnnotationBody); len(anno) > 0 {
			hasAnnotation = true
			hasBodyAnnotation = true
		}
		if anno := getAnnotation(field.Annotations, AnnotationFileName); len(anno) > 0 {
			hasAnnotation = true
			hasFormAnnotation = true
			clientMethod.FormFileCode += plain(checkSnakeName(anno[0]))
		}
		// Un-annotated fields of a GET request are sent as query parameters.
		if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") {
			clientMethod.QueryParamsCode += plain(checkSnakeName(field.GetName()))
		}
	}

	clientMethod.BodyParamsCode = meta.SetBodyParam
	if hasBodyAnnotation && hasFormAnnotation {
		// Body wins over form when both are declared.
		clientMethod.FormValueCode = ""
		clientMethod.FormFileCode = ""
	}
	if !hasBodyAnnotation && hasFormAnnotation {
		clientMethod.BodyParamsCode = ""
	}
	return nil
}
// extendServiceList is a list of service names, as collected from the Extends
// value of services.
type extendServiceList []string

// exist reports whether serviceName is an element of the list.
func (svr extendServiceList) exist(serviceName string) bool {
	for i := 0; i < len(svr); i++ {
		if svr[i] == serviceName {
			return true
		}
	}
	return false
}
// getExtendServices walks every AST yielded by ast.DepthFirstSearch and
// collects the non-empty Extends value of each service found there. The result
// is nil when no service extends another.
func getExtendServices(ast *parser.Thrift) (res extendServiceList) {
	for tree := range ast.DepthFirstSearch() {
		for _, service := range tree.Services {
			if service.Extends == "" {
				continue
			}
			res = append(res, service.Extends)
		}
	}
	return res
}
// getAllExtendFunction returns the functions svc inherits through its
// "extends" chain, collected recursively (parent first, then the parent's own
// ancestors).
//
// svc.Extends is split with semantic.SplitType:
//   - one part: the parent service lives in ast itself. If ast is not the
//     main IDL of the resolver, ast is registered as a resolver dependency and
//     the parent's request/response types are qualified with its base name.
//   - two parts ("ref.Service"): the parent lives in the included IDL "ref".
//     That IDL and its own includes are registered as resolver dependencies
//     and the parent's request/response types are qualified likewise.
//
// A parent service or reference that cannot be found contributes no functions.
// Note that processExtendsType rewrites the type names of the parent's
// functions in place.
func getAllExtendFunction(svc *parser.Service, ast *parser.Thrift, resolver *Resolver, args *config.Argument) (res []*parser.Function, err error) {
	if len(svc.Extends) == 0 {
		return
	}
	parts := semantic.SplitType(svc.Extends)
	switch len(parts) {
	case 1:
		extendSvc, found := ast.GetService(parts[0])
		if !found {
			return res, nil
		}
		funcs := extendSvc.GetFunctions()
		if resolver.mainPkg.Ast.Filename != ast.Filename {
			// extended current service for other IDL: the types must be
			// qualified with the package of that IDL
			base, err := addResolverDependency(resolver, ast, args)
			if err != nil {
				return nil, err
			}
			for _, f := range funcs {
				processExtendsType(f, base)
			}
		}
		// determine if it still has extends
		extendFuncs, err := getAllExtendFunction(extendSvc, ast, resolver, args)
		if err != nil {
			return nil, err
		}
		// Append into res step by step so the parent's own function slice is
		// never used as an append destination.
		res = append(res, funcs...)
		res = append(res, extendFuncs...)
		return res, nil
	case 2:
		refAst, found := ast.GetReference(parts[0])
		if !found {
			// Check before touching refAst: using it when the reference is
			// missing would dereference a nil AST.
			return res, nil
		}
		base, err := addResolverDependency(resolver, refAst, args)
		if err != nil {
			return nil, err
		}
		// if the service extends from other files, it has to resolve the dependencies of other files as well
		for _, dep := range refAst.Includes {
			if _, err := addResolverDependency(resolver, dep.Reference, args); err != nil {
				return nil, err
			}
		}
		extendSvc, found := refAst.GetService(parts[1])
		if !found {
			return res, nil
		}
		funcs := extendSvc.GetFunctions()
		for _, f := range funcs {
			processExtendsType(f, base)
		}
		extendFuncs, err := getAllExtendFunction(extendSvc, refAst, resolver, args)
		if err != nil {
			return nil, err
		}
		res = append(res, funcs...)
		res = append(res, extendFuncs...)
		return res, nil
	}
	return res, nil
}
// processExtendsType qualifies the request and response types of an inherited
// function with the package name base.
//
// When a method of another file is extended, the package of req/resp needs to
// be changed, e.g. for base.thrift:
//
//	Resp Method(Req){}  ->  base.Resp Method(base.Req){}
//
// Only the first argument and the return type are considered. A name is
// rewritten only if it is a struct category and does not already contain a
// ".". For set/list types the value type is rewritten; for map types the value
// type and then the key type; nested containers are not descended into. The
// function's types are modified in place.
func processExtendsType(f *parser.Function, base string) {
	// qualify prefixes an unqualified struct type name with base.
	qualify := func(t *parser.Type) {
		if strings.Contains(t.Name, ".") || !t.Category.IsStruct() {
			return
		}
		t.Name = base + "." + t.Name
	}

	// rewrite applies qualify to t itself, or to its direct element types when
	// t is a container.
	rewrite := func(t *parser.Type) {
		if !t.Category.IsContainerType() {
			qualify(t)
			return
		}
		switch t.Category {
		case parser.Category_Set, parser.Category_List:
			qualify(t.ValueType)
		case parser.Category_Map:
			qualify(t.ValueType)
			qualify(t.KeyType)
		}
	}

	if len(f.Arguments) > 0 {
		rewrite(f.Arguments[0].Type)
	}
	rewrite(f.FunctionType)
}
// getUniqueResolveDependentName returns a key that is not yet present in
// resolver.deps. It tries name first and then name followed by the numeric
// suffixes 0, 1, 2, ... After 10000 probes it gives up and returns the last
// generated candidate without checking it.
func getUniqueResolveDependentName(name string, resolver *Resolver) string {
	candidate := name
	for suffix := 0; suffix < 10000; suffix++ {
		if _, taken := resolver.deps[candidate]; !taken {
			return candidate
		}
		candidate = fmt.Sprintf("%s%d", name, suffix)
	}
	return candidate
}
// addResolverDependency loads ast into the resolver and registers it as a
// dependency, returning the name it was registered under.
//
// The name defaults to the base name of the IDL file. If a different file is
// already registered in resolver.refPkgs under that name, a unique name is
// generated instead. The namespace returned by resolver.LoadOne is stored in
// resolver.deps, and a PackageReference is stored in resolver.refPkgs unless
// one already exists under the final name.
func addResolverDependency(resolver *Resolver, ast *parser.Thrift, args *config.Argument) (string, error) {
	namespace, err := resolver.LoadOne(ast)
	if err != nil {
		return "", err
	}

	depName := util.BaseName(ast.Filename, ".thrift")
	known, registered := resolver.refPkgs[depName]
	switch {
	case !registered:
		resolver.deps[depName] = namespace
	case ast.Filename != known.Ast.Filename:
		// Same base name but another file: pick a non-conflicting name.
		depName = getUniqueResolveDependentName(depName, resolver)
		resolver.deps[depName] = namespace
	}

	goPkg := getGoPackage(ast, args.OptPkgMap)
	uniquePkgName, err := util.GetPackageUniqueName(util.SplitPackageName(goPkg, ""))
	if err != nil {
		return "", err
	}

	if _, present := resolver.refPkgs[depName]; !present {
		resolver.refPkgs[depName] = &PackageReference{depName, ast.Filename, &model.Model{
			FilePath:    ast.Filename,
			Package:     goPkg,
			PackageName: uniquePkgName,
		}, ast, false}
	}
	return depName, nil
}
/*---------------------------Model-----------------------------*/
var (
	// BaseThrift is a zero-valued parser.Thrift.
	// NOTE(review): presumably a placeholder AST for base types; confirm
	// against its users.
	BaseThrift = parser.Thrift{}

	// baseTypes maps thrift base type names to the Go type names emitted for
	// them.
	baseTypes = map[string]string{
		"bool":   "bool",
		"byte":   "int8",
		"i8":     "int8",
		"i16":    "int16",
		"i32":    "int32",
		"i64":    "int64",
		"double": "float64",
		"string": "string",
		"binary": "[]byte",
	}
)
// switchBaseType returns the predefined model type matching the name of a
// thrift base type, or nil when typ.Name is not a known base type name.
func switchBaseType(typ *parser.Type) *model.Type {
	var resolved *model.Type
	switch name := typ.Name; name {
	case "bool":
		resolved = model.TypeBool
	case "byte":
		resolved = model.TypeByte
	case "i8":
		resolved = model.TypeInt8
	case "i16":
		resolved = model.TypeInt16
	case "i32":
		resolved = model.TypeInt32
	case "i64":
		resolved = model.TypeInt64
	case "int":
		resolved = model.TypeInt
	case "double":
		resolved = model.TypeFloat64
	case "string":
		resolved = model.TypeString
	case "binary":
		resolved = model.TypeBinary
	}
	return resolved
}
// newBaseType returns a shallow copy of typ whose Category is replaced by cg.
// The type passed in is left unmodified.
func newBaseType(typ *model.Type, cg model.Category) *model.Type {
	clone := new(model.Type)
	*clone = *typ
	clone.Category = cg
	return clone
}
// newStructType builds a struct-kind model type with the given name and
// category. HasNew is set; Scope, Indirect and Extra keep their zero values.
func newStructType(name string, cg model.Category) *model.Type {
	t := new(model.Type)
	t.Name = name
	t.Kind = model.KindStruct
	t.Category = cg
	t.HasNew = true
	return t
}
// newEnumType builds an int-kind model type with the given name and category,
// scoped to model.BaseModel.
func newEnumType(name string, cg model.Category) *model.Type {
	t := new(model.Type)
	t.Name = name
	t.Scope = &model.BaseModel
	t.Kind = model.KindInt
	t.Category = cg
	return t
}
// newFuncType builds a func-kind model type with the given name and category.
// Scope, Indirect, Extra and HasNew keep their zero values.
func newFuncType(name string, cg model.Category) *model.Type {
	t := new(model.Type)
	t.Name = name
	t.Kind = model.KindFunc
	t.Category = cg
	return t
}
// getFieldType resolves typ to a model type. Base types are tried first via
// getBaseType (its error is ignored and only a non-nil result counts); other
// names are looked up with resolver.Get. An "unknown type" error is returned
// when neither lookup succeeds.
func (resolver *Resolver) getFieldType(typ *parser.Type) (*model.Type, error) {
	baseType, _ := resolver.getBaseType(typ)
	if baseType != nil {
		return baseType, nil
	}

	symbol := resolver.Get(typ.Name)
	if symbol == nil {
		return nil, fmt.Errorf("unknown type: %s", typ.Name)
	}
	return symbol.Type, nil
}
// ResolvedSymbol is a Symbol together with the names needed to reference it in
// generated code.
type ResolvedSymbol struct {
	// Base is the identifier of the symbol before name styling.
	Base string
	// Src, when non-empty, is the qualifier put in front of the identifier
	// ("Src.Base").
	Src string
	*Symbol
}

// Expression renders the Go expression that refers to the symbol.
//
// Base is converted with NameStyle; if that fails a warning is logged and Base
// is used unchanged. Base types are mapped through baseTypes instead of being
// styled. The result is prefixed with "Src." when Src is set, except for base
// types that are not values, which are never qualified.
func (rs ResolvedSymbol) Expression() string {
	ident, err := NameStyle.Identify(rs.Base)
	if err != nil {
		logs.Warnf("%s naming style for %s failed, fall back to %s, please refer to the variable manually!", NameStyle.Name(), rs.Base, rs.Base)
		ident = rs.Base
	}

	// base type no need to do name style: use the base type mapping instead
	if model.IsBaseType(rs.Type) {
		if mapped, ok := baseTypes[rs.Base]; ok {
			ident = mapped
		}
	}

	if rs.Src == "" {
		return ident
	}
	if !rs.IsValue && model.IsBaseType(rs.Type) {
		return ident
	}
	return rs.Src + "." + ident
}
// astToModel converts the declarations of ast into the resolver's main model
// (rs.mainPkg.Model, or a fresh model.Model when that is nil) and returns it.
//
// The following parts of the model are replaced, in this order: typedefs,
// constants and variables (list/map/set constants become variables), enums,
// and structs (structs first, then unions, then exceptions). Finally the
// argument and return types of every service function are resolved, so that
// models referenced only from a service are picked up by the resolver as well.
//
// The first type, constant or tag that cannot be resolved aborts the
// conversion with an error.
func astToModel(ast *parser.Thrift, rs *Resolver) (*model.Model, error) {
	main := rs.mainPkg.Model
	if main == nil {
		main = new(model.Model)
	}

	// typedefs
	tds := ast.GetTypedefs()
	typdefs := make([]model.TypeDef, 0, len(tds))
	for _, t := range tds {
		bt, err := rs.ResolveType(t.Type)
		if bt == nil || err != nil {
			// %v (not %s) so that a nil err is rendered as "<nil>" rather than
			// as a formatting-error marker.
			return nil, fmt.Errorf("%s has no type definition, error: %v", t.String(), err)
		}
		typdefs = append(typdefs, model.TypeDef{
			Scope: main,
			Alias: t.Alias,
			Type:  bt,
		})
	}
	main.Typedefs = typdefs

	// constants
	cts := ast.GetConstants()
	constants := make([]model.Constant, 0, len(cts))
	variables := make([]model.Variable, 0, len(cts))
	for _, c := range cts {
		ft, err := rs.ResolveType(c.Type)
		if err != nil {
			return nil, err
		}
		resolveValue, err := rs.ResolveConstantValue(c.Value)
		if err != nil {
			return nil, err
		}
		// Container-typed constants are emitted as variables.
		if ft.Name == model.TypeBaseList.Name || ft.Name == model.TypeBaseMap.Name || ft.Name == model.TypeBaseSet.Name {
			variables = append(variables, model.Variable{
				Scope: main,
				Name:  c.Name,
				Type:  ft,
				Value: resolveValue,
			})
		} else {
			constants = append(constants, model.Constant{
				Scope: main,
				Name:  c.Name,
				Type:  ft,
				Value: resolveValue,
			})
		}
	}
	main.Constants = constants
	main.Variables = variables

	// Enums
	ems := ast.GetEnums()
	enums := make([]model.Enum, 0, len(ems))
	for _, e := range ems {
		em := model.Enum{
			Scope:  main,
			Name:   e.GetName(),
			GoType: "int64",
		}
		vs := make([]model.Constant, 0, len(e.Values))
		for _, ee := range e.Values {
			vs = append(vs, model.Constant{
				Scope: main,
				Name:  ee.Name,
				Type:  model.TypeInt64,
				Value: model.IntExpression{Src: int(ee.Value)},
			})
		}
		em.Values = vs
		enums = append(enums, em)
	}
	main.Enums = enums

	// Structs, unions and exceptions share one conversion; only the category
	// carried by the prototype differs.
	structs := make([]model.Struct, 0, len(ast.Structs)+len(ast.Unions)+len(ast.Exceptions))
	appendStructLikes := func(defs []*parser.StructLike, proto model.Struct) error {
		for _, st := range defs {
			// A fresh variable per definition: the fields keep a pointer to it.
			s := proto
			s.Scope = main
			s.Name = st.GetName()
			s.LeadingComments = removeCommentsSlash(st.GetReservedComments())

			vs := make([]model.Field, 0, len(st.Fields))
			for _, f := range st.Fields {
				// The error of Identify is ignored, as before.
				fieldName, _ := (&styles.ThriftGo{}).Identify(f.Name)
				isP, err := isPointer(f, rs)
				if err != nil {
					return err
				}
				resolveType, err := rs.ResolveType(f.Type)
				if err != nil {
					return err
				}
				field := model.Field{
					Scope: &s,
					Name:  fieldName,
					Type:  resolveType,
					// IsSetDefault: f.IsSetDefault(),
					LeadingComments: removeCommentsSlash(f.GetReservedComments()),
					IsPointer:       isP,
				}
				if err = injectTags(f, &field, true, true); err != nil {
					return err
				}
				vs = append(vs, field)
			}
			checkDuplicatedFileName(vs)
			s.Fields = vs
			structs = append(structs, s)
		}
		return nil
	}
	if err := appendStructLikes(ast.Structs, model.Struct{Category: model.CategoryStruct}); err != nil {
		return nil, err
	}
	if err := appendStructLikes(ast.Unions, model.Struct{Category: model.CategoryUnion}); err != nil {
		return nil, err
	}
	if err := appendStructLikes(ast.Exceptions, model.Struct{Category: model.CategoryException}); err != nil {
		return nil, err
	}
	main.Structs = structs

	// In case of only the service refers another model, therefore scanning service is necessary
	for _, s := range ast.GetServices() {
		for _, m := range s.GetFunctions() {
			if _, err := rs.ResolveType(m.GetFunctionType()); err != nil {
				return nil, err
			}
			for _, a := range m.GetArguments() {
				if _, err := rs.ResolveType(a.GetType()); err != nil {
					return nil, err
				}
			}
		}
	}
	return main, nil
}
// removeCommentsSlash strips the leading two-character comment marker (the
// double slash of a thrift comment) and returns the remaining text.
//
// The first two bytes are dropped without inspecting them. Input shorter than
// the marker (empty or a single byte) has no text after it, so "" is returned
// instead of slicing out of range.
func removeCommentsSlash(comments string) string {
	const markerLen = 2
	if len(comments) < markerLen {
		return ""
	}
	return comments[markerLen:]
}
// isPointer reports whether the generated Go field for f should be a pointer.
//
// Fields whose resolved type is a struct, map or slice kind are never
// pointers; any other field is a pointer exactly when it is optional. An error
// is returned when the field type cannot be resolved.
func isPointer(f *parser.Field, rs *Resolver) (bool, error) {
	resolved, err := rs.ResolveType(f.GetType())
	switch {
	case err != nil:
		return false, err
	case resolved == nil:
		return false, fmt.Errorf("can not get type: %s for %s", f.GetType(), f.GetName())
	}

	switch resolved.Kind {
	case model.KindStruct, model.KindMap, model.KindSlice:
		return false, nil
	}
	return f.GetRequiredness().IsOptional(), nil
}
// getNewFieldName appends underscores to fieldName until the result is not a
// key of fieldNameSet, and returns that result. Only key presence matters; the
// stored boolean is not consulted. The set itself is not modified.
func getNewFieldName(fieldName string, fieldNameSet map[string]bool) string {
	for {
		if _, taken := fieldNameSet[fieldName]; !taken {
			return fieldName
		}
		fieldName += "_"
	}
}
// checkDuplicatedFileName makes the field names in vs unique, in place. The
// first occurrence of a name is kept; every later field whose name is already
// taken is renamed by appending underscores (see getNewFieldName) until it no
// longer collides with any name seen so far.
func checkDuplicatedFileName(vs []model.Field) {
	seen := make(map[string]bool)
	for idx := range vs {
		current := &vs[idx]
		if _, duplicated := seen[current.Name]; duplicated {
			current.Name = getNewFieldName(current.Name, seen)
		}
		seen[current.Name] = true
	}
}