761 lines
23 KiB
Go
761 lines
23 KiB
Go
/*
|
|
* Copyright 2022 CloudWeGo Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package protobuf
|
|
|
|
import (
|
|
"fmt"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/cloudwego/hertz/cmd/hz/generator"
|
|
"github.com/cloudwego/hertz/cmd/hz/generator/model"
|
|
"github.com/cloudwego/hertz/cmd/hz/meta"
|
|
"github.com/cloudwego/hertz/cmd/hz/protobuf/api"
|
|
"github.com/cloudwego/hertz/cmd/hz/util"
|
|
"github.com/cloudwego/hertz/cmd/hz/util/logs"
|
|
"github.com/jhump/protoreflect/desc"
|
|
"google.golang.org/protobuf/compiler/protogen"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
"google.golang.org/protobuf/runtime/protoimpl"
|
|
"google.golang.org/protobuf/types/descriptorpb"
|
|
)
|
|
|
|
var BaseProto = descriptorpb.FileDescriptorProto{}
|
|
|
|
// getGoPackage get option go_package
|
|
// If pkgMap is specified, the specified value is used as the go_package;
|
|
// If go package is not specified, then the value of package is used as go_package.
|
|
func getGoPackage(f *descriptorpb.FileDescriptorProto, pkgMap map[string]string) string {
|
|
if f.Options == nil {
|
|
f.Options = new(descriptorpb.FileOptions)
|
|
}
|
|
if f.Options.GoPackage == nil {
|
|
f.Options.GoPackage = new(string)
|
|
}
|
|
goPkg := f.Options.GetGoPackage()
|
|
|
|
// if go_package has ";", for example go_package="/a/b/c;d", we will use "/a/b/c" as go_package
|
|
if strings.Contains(goPkg, ";") {
|
|
pkg := strings.Split(goPkg, ";")
|
|
if len(pkg) == 2 {
|
|
logs.Warnf("The go_package of the file(%s) is \"%s\", hz will use \"%s\" as the go_package.", f.GetName(), goPkg, pkg[0])
|
|
goPkg = pkg[0]
|
|
}
|
|
|
|
}
|
|
|
|
if goPkg == "" {
|
|
goPkg = f.GetPackage()
|
|
}
|
|
if opt, ok := pkgMap[f.GetName()]; ok {
|
|
return opt
|
|
}
|
|
return goPkg
|
|
}
|
|
|
|
func switchBaseType(typ descriptorpb.FieldDescriptorProto_Type) *model.Type {
|
|
switch typ {
|
|
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE, descriptorpb.FieldDescriptorProto_TYPE_GROUP:
|
|
return nil
|
|
case descriptorpb.FieldDescriptorProto_TYPE_INT64:
|
|
return model.TypeInt64
|
|
case descriptorpb.FieldDescriptorProto_TYPE_INT32:
|
|
return model.TypeInt32
|
|
case descriptorpb.FieldDescriptorProto_TYPE_UINT64:
|
|
return model.TypeUint64
|
|
case descriptorpb.FieldDescriptorProto_TYPE_UINT32:
|
|
return model.TypeUint32
|
|
case descriptorpb.FieldDescriptorProto_TYPE_FIXED64:
|
|
return model.TypeUint64
|
|
case descriptorpb.FieldDescriptorProto_TYPE_FIXED32:
|
|
return model.TypeUint32
|
|
case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
|
|
return model.TypeBool
|
|
case descriptorpb.FieldDescriptorProto_TYPE_STRING:
|
|
return model.TypeString
|
|
case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
|
|
return model.TypeBinary
|
|
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32:
|
|
return model.TypeInt32
|
|
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64:
|
|
return model.TypeInt64
|
|
case descriptorpb.FieldDescriptorProto_TYPE_SINT32:
|
|
return model.TypeInt32
|
|
case descriptorpb.FieldDescriptorProto_TYPE_SINT64:
|
|
return model.TypeInt64
|
|
case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
|
|
return model.TypeFloat64
|
|
case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
|
|
return model.TypeFloat32
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func astToService(ast *descriptorpb.FileDescriptorProto, resolver *Resolver, cmdType string, gen *protogen.Plugin) ([]*generator.Service, error) {
|
|
resolver.ExportReferred(true, false)
|
|
ss := ast.GetService()
|
|
out := make([]*generator.Service, 0, len(ss))
|
|
var merges model.Models
|
|
|
|
for _, s := range ss {
|
|
service := &generator.Service{
|
|
Name: s.GetName(),
|
|
}
|
|
|
|
service.BaseDomain = ""
|
|
domainAnno := getCompatibleAnnotation(s.GetOptions(), api.E_BaseDomain, api.E_BaseDomainCompatible)
|
|
if cmdType == meta.CmdClient {
|
|
val, ok := domainAnno.(string)
|
|
if ok && len(val) != 0 {
|
|
service.BaseDomain = val
|
|
}
|
|
}
|
|
|
|
ms := s.GetMethod()
|
|
methods := make([]*generator.HttpMethod, 0, len(ms))
|
|
clientMethods := make([]*generator.ClientMethod, 0, len(ms))
|
|
servicePathAnno := checkFirstOption(api.E_ServicePath, s.GetOptions())
|
|
servicePath := ""
|
|
if val, ok := servicePathAnno.(string); ok {
|
|
servicePath = val
|
|
}
|
|
for _, m := range ms {
|
|
rs := getAllOptions(HttpMethodOptions, m.GetOptions())
|
|
if len(rs) == 0 {
|
|
continue
|
|
}
|
|
httpOpts := httpOptions{}
|
|
for k, v := range rs {
|
|
httpOpts = append(httpOpts, httpOption{
|
|
method: k,
|
|
path: v.(string),
|
|
})
|
|
}
|
|
// turn the map into a slice and sort it to make sure getting the results in the same order every time
|
|
sort.Sort(httpOpts)
|
|
|
|
var handlerOutDir string
|
|
genPath := getCompatibleAnnotation(m.GetOptions(), api.E_HandlerPath, api.E_HandlerPathCompatible)
|
|
handlerOutDir, ok := genPath.(string)
|
|
if !ok || len(handlerOutDir) == 0 {
|
|
handlerOutDir = ""
|
|
}
|
|
if len(handlerOutDir) == 0 {
|
|
handlerOutDir = servicePath
|
|
}
|
|
|
|
// protoGoInfo can get generated "Go Info" for proto file.
|
|
// the type name may be different between "***.proto" and "***.pb.go"
|
|
protoGoInfo, exist := gen.FilesByPath[ast.GetName()]
|
|
if !exist {
|
|
return nil, fmt.Errorf("file(%s) can not exist", ast.GetName())
|
|
}
|
|
methodGoInfo, err := getMethod(protoGoInfo, m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
inputGoType := methodGoInfo.Input
|
|
outputGoType := methodGoInfo.Output
|
|
|
|
reqName := m.GetInputType()
|
|
sb, err := resolver.ResolveIdentifier(reqName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
reqName = util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") + "." + inputGoType.GoIdent.GoName
|
|
reqRawName := inputGoType.GoIdent.GoName
|
|
reqPackage := util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "")
|
|
respName := m.GetOutputType()
|
|
st, err := resolver.ResolveIdentifier(respName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
respName = util.BaseName(st.Scope.GetOptions().GetGoPackage(), "") + "." + outputGoType.GoIdent.GoName
|
|
respRawName := outputGoType.GoIdent.GoName
|
|
respPackage := util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "")
|
|
|
|
var serializer string
|
|
sl, sv := checkFirstOptions(SerializerOptions, m.GetOptions())
|
|
if sl != "" {
|
|
serializer = sv.(string)
|
|
}
|
|
|
|
method := &generator.HttpMethod{
|
|
Name: util.CamelString(m.GetName()),
|
|
HTTPMethod: httpOpts[0].method,
|
|
Path: httpOpts[0].path,
|
|
Serializer: serializer,
|
|
OutputDir: handlerOutDir,
|
|
GenHandler: true,
|
|
}
|
|
|
|
goOptMapAlias := make(map[string]string, 1)
|
|
refs := resolver.ExportReferred(false, true)
|
|
method.Models = make(map[string]*model.Model, len(refs))
|
|
for _, ref := range refs {
|
|
if val, exist := method.Models[ref.Model.PackageName]; exist {
|
|
if val.Package == ref.Model.Package {
|
|
method.Models[ref.Model.PackageName] = ref.Model
|
|
goOptMapAlias[ref.Model.Package] = ref.Model.PackageName
|
|
} else {
|
|
file := filepath.Base(ref.Model.FilePath)
|
|
fileName := strings.Split(file, ".")
|
|
newPkg := fileName[len(fileName)-2] + "_" + val.PackageName
|
|
method.Models[newPkg] = ref.Model
|
|
goOptMapAlias[ref.Model.Package] = newPkg
|
|
}
|
|
continue
|
|
}
|
|
method.Models[ref.Model.PackageName] = ref.Model
|
|
goOptMapAlias[ref.Model.Package] = ref.Model.PackageName
|
|
}
|
|
merges = service.Models
|
|
merges.MergeMap(method.Models)
|
|
if goOptMapAlias[sb.Scope.GetOptions().GetGoPackage()] != "" {
|
|
reqName = goOptMapAlias[sb.Scope.GetOptions().GetGoPackage()] + "." + inputGoType.GoIdent.GoName
|
|
}
|
|
if goOptMapAlias[sb.Scope.GetOptions().GetGoPackage()] != "" {
|
|
respName = goOptMapAlias[st.Scope.GetOptions().GetGoPackage()] + "." + outputGoType.GoIdent.GoName
|
|
}
|
|
method.RequestTypeName = reqName
|
|
method.RequestTypeRawName = reqRawName
|
|
method.RequestTypePackage = reqPackage
|
|
method.ReturnTypeName = respName
|
|
method.ReturnTypeRawName = respRawName
|
|
method.ReturnTypePackage = respPackage
|
|
|
|
methods = append(methods, method)
|
|
for idx, anno := range httpOpts {
|
|
if idx == 0 {
|
|
continue
|
|
}
|
|
tmp := *method
|
|
tmp.HTTPMethod = anno.method
|
|
tmp.Path = anno.path
|
|
tmp.GenHandler = false
|
|
methods = append(methods, &tmp)
|
|
}
|
|
|
|
if cmdType == meta.CmdClient {
|
|
clientMethod := &generator.ClientMethod{}
|
|
clientMethod.HttpMethod = method
|
|
err := parseAnnotationToClient(clientMethod, gen, ast, m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
clientMethods = append(clientMethods, clientMethod)
|
|
}
|
|
}
|
|
|
|
service.ClientMethods = clientMethods
|
|
service.Methods = methods
|
|
service.Models = merges
|
|
out = append(out, service)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func getCompatibleAnnotation(options proto.Message, anno, compatibleAnno *protoimpl.ExtensionInfo) interface{} {
|
|
if proto.HasExtension(options, anno) {
|
|
return checkFirstOption(anno, options)
|
|
} else if proto.HasExtension(options, compatibleAnno) {
|
|
return checkFirstOption(compatibleAnno, options)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen.Plugin, ast *descriptorpb.FileDescriptorProto, m *descriptorpb.MethodDescriptorProto) error {
|
|
file, exist := gen.FilesByPath[ast.GetName()]
|
|
if !exist {
|
|
return fmt.Errorf("file(%s) can not exist", ast.GetName())
|
|
}
|
|
method, err := getMethod(file, m)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// pb input type must be message
|
|
inputType := method.Input
|
|
var (
|
|
hasBodyAnnotation bool
|
|
hasFormAnnotation bool
|
|
)
|
|
for _, f := range inputType.Fields {
|
|
hasAnnotation := false
|
|
isStringFieldType := false
|
|
if f.Desc.Kind() == protoreflect.StringKind {
|
|
isStringFieldType = true
|
|
}
|
|
if proto.HasExtension(f.Desc.Options(), api.E_Query) {
|
|
hasAnnotation = true
|
|
queryAnnos := proto.GetExtension(f.Desc.Options(), api.E_Query)
|
|
val := checkSnakeName(queryAnnos.(string))
|
|
clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
|
|
}
|
|
|
|
if proto.HasExtension(f.Desc.Options(), api.E_Path) {
|
|
hasAnnotation = true
|
|
pathAnnos := proto.GetExtension(f.Desc.Options(), api.E_Path)
|
|
val := checkSnakeName(pathAnnos.(string))
|
|
if isStringFieldType {
|
|
clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
|
|
} else {
|
|
clientMethod.PathParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName)
|
|
}
|
|
}
|
|
|
|
if proto.HasExtension(f.Desc.Options(), api.E_Header) {
|
|
hasAnnotation = true
|
|
headerAnnos := proto.GetExtension(f.Desc.Options(), api.E_Header)
|
|
val := checkSnakeName(headerAnnos.(string))
|
|
if isStringFieldType {
|
|
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
|
|
} else {
|
|
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName)
|
|
}
|
|
}
|
|
|
|
if formAnnos := getCompatibleAnnotation(f.Desc.Options(), api.E_Form, api.E_FormCompatible); formAnnos != nil {
|
|
hasAnnotation = true
|
|
hasFormAnnotation = true
|
|
val := checkSnakeName(formAnnos.(string))
|
|
if isStringFieldType {
|
|
clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
|
|
} else {
|
|
clientMethod.FormValueCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName)
|
|
}
|
|
}
|
|
|
|
if proto.HasExtension(f.Desc.Options(), api.E_Body) {
|
|
hasAnnotation = true
|
|
hasBodyAnnotation = true
|
|
}
|
|
|
|
if fileAnnos := getCompatibleAnnotation(f.Desc.Options(), api.E_FileName, api.E_FileNameCompatible); fileAnnos != nil {
|
|
hasAnnotation = true
|
|
hasFormAnnotation = true
|
|
val := checkSnakeName(fileAnnos.(string))
|
|
clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
|
|
}
|
|
if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") {
|
|
clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(string(f.Desc.Name())), f.GoName)
|
|
}
|
|
}
|
|
clientMethod.BodyParamsCode = meta.SetBodyParam
|
|
if hasBodyAnnotation && hasFormAnnotation {
|
|
clientMethod.FormValueCode = ""
|
|
clientMethod.FormFileCode = ""
|
|
}
|
|
if !hasBodyAnnotation && hasFormAnnotation {
|
|
clientMethod.BodyParamsCode = ""
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getMethod(file *protogen.File, m *descriptorpb.MethodDescriptorProto) (*protogen.Method, error) {
|
|
for _, f := range file.Services {
|
|
for _, method := range f.Methods {
|
|
if string(method.Desc.Name()) == m.GetName() {
|
|
return method, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("can not find method: %s", m.GetName())
|
|
}
|
|
|
|
//---------------------------------Model--------------------------------
|
|
|
|
func astToModel(ast *descriptorpb.FileDescriptorProto, rs *Resolver) (*model.Model, error) {
|
|
main := rs.mainPkg.Model
|
|
if main == nil {
|
|
main = new(model.Model)
|
|
}
|
|
|
|
mainFileDes := rs.files.PbReflect[ast.GetName()]
|
|
isProto3 := mainFileDes.IsProto3()
|
|
// Enums
|
|
ems := ast.GetEnumType()
|
|
enums := make([]model.Enum, 0, len(ems))
|
|
for _, e := range ems {
|
|
em := model.Enum{
|
|
Scope: main,
|
|
Name: e.GetName(),
|
|
GoType: "int32",
|
|
}
|
|
es := e.GetValue()
|
|
vs := make([]model.Constant, 0, len(es))
|
|
for _, ee := range es {
|
|
vs = append(vs, model.Constant{
|
|
Scope: main,
|
|
Name: ee.GetName(),
|
|
Type: model.TypeInt32,
|
|
Value: model.IntExpression{Src: int(ee.GetNumber())},
|
|
})
|
|
}
|
|
em.Values = vs
|
|
enums = append(enums, em)
|
|
}
|
|
main.Enums = enums
|
|
|
|
// Structs
|
|
sts := ast.GetMessageType()
|
|
structs := make([]model.Struct, 0, len(sts)*2)
|
|
oneofs := make([]model.Oneof, 0, 1)
|
|
for _, st := range sts {
|
|
stMessage := mainFileDes.FindMessage(ast.GetPackage() + "." + st.GetName())
|
|
stLeadingComments := getMessageLeadingComments(stMessage)
|
|
s := model.Struct{
|
|
Scope: main,
|
|
Name: st.GetName(),
|
|
Category: model.CategoryStruct,
|
|
LeadingComments: stLeadingComments,
|
|
}
|
|
|
|
ns := st.GetNestedType()
|
|
nestedMessageInfoMap := getNestedMessageInfoMap(stMessage)
|
|
for _, nt := range ns {
|
|
if IsMapEntry(nt) {
|
|
continue
|
|
}
|
|
|
|
nestedMessageInfo := nestedMessageInfoMap[nt.GetName()]
|
|
nestedMessageLeadingComment := getMessageLeadingComments(nestedMessageInfo)
|
|
s := model.Struct{
|
|
Scope: main,
|
|
Name: st.GetName() + "_" + nt.GetName(),
|
|
Category: model.CategoryStruct,
|
|
LeadingComments: nestedMessageLeadingComment,
|
|
}
|
|
fs := nt.GetField()
|
|
ns := nt.GetNestedType()
|
|
vs := make([]model.Field, 0, len(fs))
|
|
|
|
oneofMap := make(map[string]model.Field)
|
|
oneofType, err := resolveOneof(nestedMessageInfo, oneofMap, rs, isProto3, s, ns)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
oneofs = append(oneofs, oneofType...)
|
|
|
|
choiceSet := make(map[string]bool)
|
|
|
|
for _, f := range fs {
|
|
if field, exist := oneofMap[f.GetName()]; exist {
|
|
if _, ex := choiceSet[field.Name]; !ex {
|
|
choiceSet[field.Name] = true
|
|
vs = append(vs, field)
|
|
}
|
|
continue
|
|
}
|
|
dv := f.GetDefaultValue()
|
|
fieldLeadingComments, fieldTrailingComments := getFiledComments(f, nestedMessageInfo)
|
|
t, err := rs.ResolveType(f, ns)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
field := model.Field{
|
|
Scope: &s,
|
|
Name: util.CamelString(f.GetName()),
|
|
Type: t,
|
|
LeadingComments: fieldLeadingComments,
|
|
TrailingComments: fieldTrailingComments,
|
|
IsPointer: isPointer(f, isProto3),
|
|
}
|
|
if dv != "" {
|
|
field.IsSetDefault = true
|
|
field.DefaultValue, err = parseDefaultValue(f.GetType(), f.GetDefaultValue())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
err = injectTagsToModel(f, &field, true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vs = append(vs, field)
|
|
}
|
|
checkDuplicatedFileName(vs)
|
|
s.Fields = vs
|
|
structs = append(structs, s)
|
|
}
|
|
|
|
fs := st.GetField()
|
|
vs := make([]model.Field, 0, len(fs))
|
|
|
|
oneofMap := make(map[string]model.Field)
|
|
oneofType, err := resolveOneof(stMessage, oneofMap, rs, isProto3, s, ns)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
oneofs = append(oneofs, oneofType...)
|
|
|
|
choiceSet := make(map[string]bool)
|
|
|
|
for _, f := range fs {
|
|
if field, exist := oneofMap[f.GetName()]; exist {
|
|
if _, ex := choiceSet[field.Name]; !ex {
|
|
choiceSet[field.Name] = true
|
|
vs = append(vs, field)
|
|
}
|
|
continue
|
|
}
|
|
dv := f.GetDefaultValue()
|
|
fieldLeadingComments, fieldTrailingComments := getFiledComments(f, stMessage)
|
|
t, err := rs.ResolveType(f, ns)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
field := model.Field{
|
|
Scope: &s,
|
|
Name: util.CamelString(f.GetName()),
|
|
Type: t,
|
|
LeadingComments: fieldLeadingComments,
|
|
TrailingComments: fieldTrailingComments,
|
|
IsPointer: isPointer(f, isProto3),
|
|
}
|
|
if dv != "" {
|
|
field.IsSetDefault = true
|
|
field.DefaultValue, err = parseDefaultValue(f.GetType(), f.GetDefaultValue())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
err = injectTagsToModel(f, &field, true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vs = append(vs, field)
|
|
}
|
|
checkDuplicatedFileName(vs)
|
|
s.Fields = vs
|
|
structs = append(structs, s)
|
|
|
|
}
|
|
main.Oneofs = oneofs
|
|
main.Structs = structs
|
|
|
|
// In case of only the service refers another model, therefore scanning service is necessary
|
|
ss := ast.GetService()
|
|
for _, s := range ss {
|
|
ms := s.GetMethod()
|
|
for _, m := range ms {
|
|
_, err := rs.ResolveIdentifier(m.GetInputType())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
_, err = rs.ResolveIdentifier(m.GetOutputType())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
|
|
return main, nil
|
|
}
|
|
|
|
// getMessageLeadingComments can get struct LeadingComment
|
|
func getMessageLeadingComments(stMessage *desc.MessageDescriptor) string {
|
|
if stMessage == nil {
|
|
return ""
|
|
}
|
|
stComments := stMessage.GetSourceInfo().GetLeadingComments()
|
|
stComments = formatComments(stComments)
|
|
|
|
return stComments
|
|
}
|
|
|
|
// getFiledComments can get field LeadingComments and field TailingComments for field
|
|
func getFiledComments(f *descriptorpb.FieldDescriptorProto, stMessage *desc.MessageDescriptor) (string, string) {
|
|
if stMessage == nil {
|
|
return "", ""
|
|
}
|
|
|
|
fieldNum := f.GetNumber()
|
|
field := stMessage.FindFieldByNumber(fieldNum)
|
|
fieldInfo := field.GetSourceInfo()
|
|
|
|
fieldLeadingComments := fieldInfo.GetLeadingComments()
|
|
fieldTailingComments := fieldInfo.GetTrailingComments()
|
|
|
|
fieldLeadingComments = formatComments(fieldLeadingComments)
|
|
fieldTailingComments = formatComments(fieldTailingComments)
|
|
|
|
return fieldLeadingComments, fieldTailingComments
|
|
}
|
|
|
|
// formatComments can format the comments for beauty
|
|
func formatComments(comments string) string {
|
|
if len(comments) == 0 {
|
|
return ""
|
|
}
|
|
|
|
comments = util.TrimLastChar(comments)
|
|
comments = util.AddSlashForComments(comments)
|
|
|
|
return comments
|
|
}
|
|
|
|
// getNestedMessageInfoMap can get all nested struct
|
|
func getNestedMessageInfoMap(stMessage *desc.MessageDescriptor) map[string]*desc.MessageDescriptor {
|
|
nestedMessage := stMessage.GetNestedMessageTypes()
|
|
nestedMessageInfoMap := make(map[string]*desc.MessageDescriptor, len(nestedMessage))
|
|
|
|
for _, nestedMsg := range nestedMessage {
|
|
nestedMsgName := nestedMsg.GetName()
|
|
nestedMessageInfoMap[nestedMsgName] = nestedMsg
|
|
}
|
|
|
|
return nestedMessageInfoMap
|
|
}
|
|
|
|
func parseDefaultValue(typ descriptorpb.FieldDescriptorProto_Type, val string) (model.Literal, error) {
|
|
switch typ {
|
|
case descriptorpb.FieldDescriptorProto_TYPE_BYTES, descriptorpb.FieldDescriptorProto_TYPE_STRING:
|
|
return model.StringExpression{Src: val}, nil
|
|
case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
|
|
return model.BoolExpression{Src: val == "true"}, nil
|
|
case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,
|
|
descriptorpb.FieldDescriptorProto_TYPE_FLOAT,
|
|
descriptorpb.FieldDescriptorProto_TYPE_INT64,
|
|
descriptorpb.FieldDescriptorProto_TYPE_UINT64,
|
|
descriptorpb.FieldDescriptorProto_TYPE_INT32,
|
|
descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
|
|
descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
|
|
descriptorpb.FieldDescriptorProto_TYPE_UINT32,
|
|
descriptorpb.FieldDescriptorProto_TYPE_ENUM,
|
|
descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
|
|
descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
|
|
descriptorpb.FieldDescriptorProto_TYPE_SINT32,
|
|
descriptorpb.FieldDescriptorProto_TYPE_SINT64:
|
|
return model.NumberExpression{Src: val}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported type %s", typ.String())
|
|
}
|
|
}
|
|
|
|
func isPointer(f *descriptorpb.FieldDescriptorProto, isProto3 bool) bool {
|
|
if f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE || f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_BYTES {
|
|
return false
|
|
}
|
|
|
|
if !isProto3 {
|
|
if f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REPEATED {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
switch f.GetLabel() {
|
|
case descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL:
|
|
if !f.GetProto3Optional() {
|
|
return false
|
|
}
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func resolveOneof(stMessage *desc.MessageDescriptor, oneofMap map[string]model.Field, rs *Resolver, isProto3 bool, s model.Struct, ns []*descriptorpb.DescriptorProto) ([]model.Oneof, error) {
|
|
oneofs := make([]model.Oneof, 0, 1)
|
|
if len(stMessage.GetOneOfs()) != 0 {
|
|
for _, oneof := range stMessage.GetOneOfs() {
|
|
if isProto3 {
|
|
if oneof.IsSynthetic() {
|
|
continue
|
|
}
|
|
}
|
|
oneofName := oneof.GetName()
|
|
messageName := s.Name
|
|
typeName := "is" + messageName + "_" + oneofName
|
|
field := model.Field{
|
|
Scope: &s,
|
|
Name: util.CamelString(oneofName),
|
|
Type: model.NewOneofType(typeName),
|
|
IsPointer: false,
|
|
}
|
|
|
|
oneofComment := oneof.GetSourceInfo().GetLeadingComments()
|
|
oneofComment = formatComments(oneofComment)
|
|
var oneofLeadingComments string
|
|
if oneofComment == "" {
|
|
oneofLeadingComments = fmt.Sprintf(" Types that are assignable to %s:\n", oneofName)
|
|
} else {
|
|
oneofLeadingComments = fmt.Sprintf("%s\n//\n// Types that are assignable to %s:\n", oneofComment, oneofName)
|
|
}
|
|
for idx, ch := range oneof.GetChoices() {
|
|
if idx == len(oneof.GetChoices())-1 {
|
|
oneofLeadingComments = oneofLeadingComments + fmt.Sprintf("// *%s_%s", messageName, ch.GetName())
|
|
} else {
|
|
oneofLeadingComments = oneofLeadingComments + fmt.Sprintf("// *%s_%s\n", messageName, ch.GetName())
|
|
}
|
|
}
|
|
field.LeadingComments = oneofLeadingComments
|
|
|
|
choices := make([]model.Choice, 0, len(oneof.GetChoices()))
|
|
for _, ch := range oneof.GetChoices() {
|
|
t, err := rs.ResolveType(ch.AsFieldDescriptorProto(), ns)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
choice := model.Choice{
|
|
MessageName: messageName,
|
|
ChoiceName: ch.GetName(),
|
|
Type: t,
|
|
}
|
|
choices = append(choices, choice)
|
|
oneofMap[ch.GetName()] = field
|
|
}
|
|
|
|
oneofType := model.Oneof{
|
|
MessageName: messageName,
|
|
OneofName: oneofName,
|
|
InterfaceName: typeName,
|
|
Choices: choices,
|
|
}
|
|
|
|
oneofs = append(oneofs, oneofType)
|
|
}
|
|
}
|
|
return oneofs, nil
|
|
}
|
|
|
|
func getNewFieldName(fieldName string, fieldNameSet map[string]bool) string {
|
|
if _, ex := fieldNameSet[fieldName]; ex {
|
|
fieldName = fieldName + "_"
|
|
return getNewFieldName(fieldName, fieldNameSet)
|
|
}
|
|
return fieldName
|
|
}
|
|
|
|
func checkDuplicatedFileName(vs []model.Field) {
|
|
fieldNameSet := make(map[string]bool)
|
|
for i := 0; i < len(vs); i++ {
|
|
if _, ex := fieldNameSet[vs[i].Name]; ex {
|
|
newName := getNewFieldName(vs[i].Name, fieldNameSet)
|
|
fieldNameSet[newName] = true
|
|
vs[i].Name = newName
|
|
} else {
|
|
fieldNameSet[vs[i].Name] = true
|
|
}
|
|
}
|
|
}
|