commit 51c1a6b5981232db4ea2e773ea576162df1bfcc2 Author: xuyang Date: Tue Apr 30 19:30:09 2024 +0800 register改良 diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/hz.iml b/.idea/hz.iml new file mode 100644 index 0000000..5e764c4 --- /dev/null +++ b/.idea/hz.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..d4ba7c1 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..b2bdec2 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/app/app.go b/app/app.go new file mode 100644 index 0000000..2313bb8 --- /dev/null +++ b/app/app.go @@ -0,0 +1,430 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package app + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/config" + "github.com/cloudwego/hertz/cmd/hz/generator" + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/protobuf" + "github.com/cloudwego/hertz/cmd/hz/thrift" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/hertz/cmd/hz/util/logs" + "github.com/urfave/cli/v2" +) + +// global args. MUST fork it when use +var globalArgs = config.NewArgument() + +func New(c *cli.Context) error { + args, err := globalArgs.Parse(c, meta.CmdNew) + if err != nil { + return cli.Exit(err, meta.LoadError) + } + setLogVerbose(args.Verbose) + logs.Debugf("args: %#v\n", args) + + exist, err := util.PathExist(filepath.Join(args.OutDir, meta.ManifestFile)) + if err != nil { + return cli.Exit(err, meta.LoadError) + } + + if exist && !args.ForceNew { + return cli.Exit(fmt.Errorf("the current is already a hertz project, if you want to regenerate it you can specify \"-force\""), meta.LoadError) + } + + err = GenerateLayout(args) + if err != nil { + return cli.Exit(err, meta.GenerateLayoutError) + } + + err = TriggerPlugin(args) + if err != nil { + return cli.Exit(err, meta.PluginError) + } + // ".hz" file converges to the hz tool + manifest := new(meta.Manifest) + args.InitManifest(manifest) + err = manifest.Persist(args.OutDir) + if err != nil { + return cli.Exit(fmt.Errorf("persist manifest failed: %v", err), meta.PersistError) + } + if !args.NeedGoMod && args.IdlType == meta.IdlThrift { + logs.Warn(meta.AddThriftReplace) + } + + return nil +} + +func Update(c *cli.Context) error { + // begin to update + args, err := globalArgs.Parse(c, meta.CmdUpdate) + if err != nil { + return cli.Exit(err, meta.LoadError) + } + setLogVerbose(args.Verbose) + logs.Debugf("Args: %#v\n", args) + + manifest := new(meta.Manifest) + err = manifest.InitAndValidate(args.OutDir) + if err != nil { + return cli.Exit(err, meta.LoadError) + } + // update argument by ".hz", can automatically get "handler_dir"/"model_dir"/"router_dir" + args.UpdateByManifest(manifest) + + err = TriggerPlugin(args) + if err != nil { + return cli.Exit(err, meta.PluginError) + } + // If the "handler_dir"/"model_dir" is updated, write it back to ".hz" + args.UpdateManifest(manifest) + err = manifest.Persist(args.OutDir) + if err != nil { + return cli.Exit(fmt.Errorf("persist manifest failed: %v", err), meta.PersistError) + } + + return nil +} + +func Model(c *cli.Context) error { + args, err := globalArgs.Parse(c, meta.CmdModel) + if err != nil { + return cli.Exit(err, meta.LoadError) + } + setLogVerbose(args.Verbose) + logs.Debugf("Args: %#v\n", args) + + err = TriggerPlugin(args) + if err != nil { + return cli.Exit(err, meta.PluginError) + } + + return nil +} + +func Client(c *cli.Context) error { + args, err := globalArgs.Parse(c, meta.CmdClient) + if err != nil { + return cli.Exit(err, meta.LoadError) + } + setLogVerbose(args.Verbose) + logs.Debugf("Args: %#v\n", args) + + err = TriggerPlugin(args) + if err != nil { + return cli.Exit(err, meta.PluginError) + } + + return nil +} + +func PluginMode() { + mode := os.Getenv(meta.EnvPluginMode) + if len(os.Args) <= 1 && mode != "" { + switch mode { + case meta.ThriftPluginName: + plugin := new(thrift.Plugin) + os.Exit(plugin.Run()) + case meta.ProtocPluginName: + plugin := new(protobuf.Plugin) + os.Exit(plugin.Run()) + } + } +} + +func Init() *cli.App { + // flags + verboseFlag := cli.BoolFlag{Name: "verbose,vv", Usage: "turn on verbose mode", Destination: &globalArgs.Verbose} + + idlFlag := cli.StringSliceFlag{Name: "idl", Usage: "Specify the IDL file path. (.thrift or .proto)"} + moduleFlag := cli.StringFlag{Name: "module", Aliases: []string{"mod"}, Usage: "Specify the Go module name.", Destination: &globalArgs.Gomod} + serviceNameFlag := cli.StringFlag{Name: "service", Usage: "Specify the service name.", Destination: &globalArgs.ServiceName} + outDirFlag := cli.StringFlag{Name: "out_dir", Usage: "Specify the project path.", Destination: &globalArgs.OutDir} + handlerDirFlag := cli.StringFlag{Name: "handler_dir", Usage: "Specify the handler relative path (based on \"out_dir\").", Destination: &globalArgs.HandlerDir} + modelDirFlag := cli.StringFlag{Name: "model_dir", Usage: "Specify the model relative path (based on \"out_dir\").", Destination: &globalArgs.ModelDir} + routerDirFlag := cli.StringFlag{Name: "router_dir", Usage: "Specify the router relative path (based on \"out_dir\").", Destination: &globalArgs.RouterDir} + useFlag := cli.StringFlag{Name: "use", Usage: "Specify the model package to import for handler.", Destination: &globalArgs.Use} + baseDomainFlag := cli.StringFlag{Name: "base_domain", Usage: "Specify the request domain.", Destination: &globalArgs.BaseDomain} + clientDirFlag := cli.StringFlag{Name: "client_dir", Usage: "Specify the client path. If not specified, IDL generated path is used for 'client' command; no client code is generated for 'new' command", Destination: &globalArgs.ClientDir} + forceClientDirFlag := cli.StringFlag{Name: "force_client_dir", Usage: "Specify the client path, and won't use namespaces as subpaths", Destination: &globalArgs.ForceClientDir} + + optPkgFlag := cli.StringSliceFlag{Name: "option_package", Aliases: []string{"P"}, Usage: "Specify the package path. ({include_path}={import_path})"} + includesFlag := cli.StringSliceFlag{Name: "proto_path", Aliases: []string{"I"}, Usage: "Add an IDL search path for includes. (Valid only if idl is protobuf)"} + excludeFilesFlag := cli.StringSliceFlag{Name: "exclude_file", Aliases: []string{"E"}, Usage: "Specify the files that do not need to be updated."} + thriftOptionsFlag := cli.StringSliceFlag{Name: "thriftgo", Aliases: []string{"t"}, Usage: "Specify arguments for the thriftgo. ({flag}={value})"} + protoOptionsFlag := cli.StringSliceFlag{Name: "protoc", Aliases: []string{"p"}, Usage: "Specify arguments for the protoc. ({flag}={value})"} + thriftPluginsFlag := cli.StringSliceFlag{Name: "thrift-plugins", Usage: "Specify plugins for the thriftgo. ({plugin_name}:{options})"} + protoPluginsFlag := cli.StringSliceFlag{Name: "protoc-plugins", Usage: "Specify plugins for the protoc. ({plugin_name}:{options}:{out_dir})"} + noRecurseFlag := cli.BoolFlag{Name: "no_recurse", Usage: "Generate master model only.", Destination: &globalArgs.NoRecurse} + forceNewFlag := cli.BoolFlag{Name: "force", Aliases: []string{"f"}, Usage: "Force new a project, which will overwrite the generated files", Destination: &globalArgs.ForceNew} + enableExtendsFlag := cli.BoolFlag{Name: "enable_extends", Usage: "Parse 'extends' for thrift IDL", Destination: &globalArgs.EnableExtends} + + jsonEnumStrFlag := cli.BoolFlag{Name: "json_enumstr", Usage: "Use string instead of num for json enums when idl is thrift.", Destination: &globalArgs.JSONEnumStr} + unsetOmitemptyFlag := cli.BoolFlag{Name: "unset_omitempty", Usage: "Remove 'omitempty' tag for generated struct.", Destination: &globalArgs.UnsetOmitempty} + protoCamelJSONTag := cli.BoolFlag{Name: "pb_camel_json_tag", Usage: "Convert Name style for json tag to camel(Only works protobuf).", Destination: &globalArgs.ProtobufCamelJSONTag} + snakeNameFlag := cli.BoolFlag{Name: "snake_tag", Usage: "Use snake_case style naming for tags. (Only works for 'form', 'query', 'json')", Destination: &globalArgs.SnakeName} + rmTagFlag := cli.StringSliceFlag{Name: "rm_tag", Usage: "Remove the default tag(json/query/form). If the annotation tag is set explicitly, it will not be removed."} + customLayout := cli.StringFlag{Name: "customize_layout", Usage: "Specify the path for layout template.", Destination: &globalArgs.CustomizeLayout} + customLayoutData := cli.StringFlag{Name: "customize_layout_data_path", Usage: "Specify the path for layout template render data.", Destination: &globalArgs.CustomizeLayoutData} + customPackage := cli.StringFlag{Name: "customize_package", Usage: "Specify the path for package template.", Destination: &globalArgs.CustomizePackage} + handlerByMethod := cli.BoolFlag{Name: "handler_by_method", Usage: "Generate a separate handler file for each method.", Destination: &globalArgs.HandlerByMethod} + + // app + app := cli.NewApp() + app.Name = "hz" + app.Usage = "A idl parser and code generator for Hertz projects" + app.Version = meta.Version + // The default separator for multiple parameters is modified to ";" + app.SliceFlagSeparator = ";" + + // global flags + app.Flags = []cli.Flag{ + &verboseFlag, + } + + // Commands + app.Commands = []*cli.Command{ + { + Name: meta.CmdNew, + Usage: "Generate a new Hertz project", + Flags: []cli.Flag{ + &idlFlag, + &serviceNameFlag, + &moduleFlag, + &outDirFlag, + &handlerDirFlag, + &modelDirFlag, + &routerDirFlag, + &clientDirFlag, + &useFlag, + + &includesFlag, + &thriftOptionsFlag, + &protoOptionsFlag, + &optPkgFlag, + &noRecurseFlag, + &forceNewFlag, + &enableExtendsFlag, + + &jsonEnumStrFlag, + &unsetOmitemptyFlag, + &protoCamelJSONTag, + &snakeNameFlag, + &rmTagFlag, + &excludeFilesFlag, + &customLayout, + &customLayoutData, + &customPackage, + &handlerByMethod, + &protoPluginsFlag, + &thriftPluginsFlag, + }, + Action: New, + }, + { + Name: meta.CmdUpdate, + Usage: "Update an existing Hertz project", + Flags: []cli.Flag{ + &idlFlag, + &moduleFlag, + &outDirFlag, + &handlerDirFlag, + &modelDirFlag, + &clientDirFlag, + &useFlag, + + &includesFlag, + &thriftOptionsFlag, + &protoOptionsFlag, + &optPkgFlag, + &noRecurseFlag, + &enableExtendsFlag, + + &jsonEnumStrFlag, + &unsetOmitemptyFlag, + &protoCamelJSONTag, + &snakeNameFlag, + &rmTagFlag, + &excludeFilesFlag, + &customPackage, + &handlerByMethod, + &protoPluginsFlag, + &thriftPluginsFlag, + }, + Action: Update, + }, + { + Name: meta.CmdModel, + Usage: "Generate model code only", + Flags: []cli.Flag{ + &idlFlag, + &moduleFlag, + &outDirFlag, + &modelDirFlag, + + &includesFlag, + &thriftOptionsFlag, + &protoOptionsFlag, + &noRecurseFlag, + + &jsonEnumStrFlag, + &unsetOmitemptyFlag, + &protoCamelJSONTag, + &snakeNameFlag, + &rmTagFlag, + &excludeFilesFlag, + }, + Action: Model, + }, + { + Name: meta.CmdClient, + Usage: "Generate hertz client based on IDL", + Flags: []cli.Flag{ + &idlFlag, + &moduleFlag, + &baseDomainFlag, + &modelDirFlag, + &clientDirFlag, + &useFlag, + &forceClientDirFlag, + + &includesFlag, + &thriftOptionsFlag, + &protoOptionsFlag, + &noRecurseFlag, + &enableExtendsFlag, + + &jsonEnumStrFlag, + &unsetOmitemptyFlag, + &protoCamelJSONTag, + &snakeNameFlag, + &rmTagFlag, + &excludeFilesFlag, + &customPackage, + &protoPluginsFlag, + &thriftPluginsFlag, + }, + Action: Client, + }, + } + return app +} + +func setLogVerbose(verbose bool) { + if verbose { + logs.SetLevel(logs.LevelDebug) + } else { + logs.SetLevel(logs.LevelWarn) + } +} + +func GenerateLayout(args *config.Argument) error { + lg := &generator.LayoutGenerator{ + TemplateGenerator: generator.TemplateGenerator{ + OutputDir: args.OutDir, + Excludes: args.Excludes, + }, + } + + layout := generator.Layout{ + GoModule: args.Gomod, + ServiceName: args.ServiceName, + UseApacheThrift: args.IdlType == meta.IdlThrift, + HasIdl: 0 != len(args.IdlPaths), + ModelDir: args.ModelDir, + HandlerDir: args.HandlerDir, + RouterDir: args.RouterDir, + NeedGoMod: args.NeedGoMod, + } + + if args.CustomizeLayout == "" { + // generate by default + err := lg.GenerateByService(layout) + if err != nil { + return fmt.Errorf("generating layout failed: %v", err) + } + } else { + // generate by customized layout + configPath, dataPath := args.CustomizeLayout, args.CustomizeLayoutData + logs.Infof("get customized layout info, layout_config_path: %s, template_data_path: %s", configPath, dataPath) + exist, err := util.PathExist(configPath) + if err != nil { + return fmt.Errorf("check customized layout config file exist failed: %v", err) + } + if !exist { + return errors.New("layout_config_path doesn't exist") + } + lg.ConfigPath = configPath + // generate by service info + if dataPath == "" { + err := lg.GenerateByService(layout) + if err != nil { + return fmt.Errorf("generating layout failed: %v", err) + } + } else { + // generate by customized data + err := lg.GenerateByConfig(dataPath) + if err != nil { + return fmt.Errorf("generating layout failed: %v", err) + } + } + } + + err := lg.Persist() + if err != nil { + return fmt.Errorf("generating layout failed: %v", err) + } + return nil +} + +func TriggerPlugin(args *config.Argument) error { + if len(args.IdlPaths) == 0 { + return nil + } + cmd, err := config.BuildPluginCmd(args) + if err != nil { + return fmt.Errorf("build plugin command failed: %v", err) + } + + compiler, err := config.IdlTypeToCompiler(args.IdlType) + if err != nil { + return fmt.Errorf("get compiler failed: %v", err) + } + + logs.Debugf("begin to trigger plugin, compiler: %s, idl_paths: %v", compiler, args.IdlPaths) + buf, err := cmd.CombinedOutput() + if err != nil { + out := strings.TrimSpace(string(buf)) + if !strings.HasSuffix(out, meta.TheUseOptionMessage) { + return fmt.Errorf("plugin %s_gen_hertz returns error: %v, cause:\n%v", compiler, err, string(buf)) + } + } + + // If len(buf) != 0, the plugin returned the log. + if len(buf) != 0 { + fmt.Println(string(buf)) + } + logs.Debugf("end run plugin %s_gen_hertz", compiler) + return nil +} diff --git a/config/argument.go b/config/argument.go new file mode 100644 index 0000000..defc505 --- /dev/null +++ b/config/argument.go @@ -0,0 +1,397 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package config + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/hertz/cmd/hz/util/logs" + "github.com/urfave/cli/v2" +) + +type Argument struct { + // Mode meta.Mode // operating mode(0-compiler, 1-plugin) + CmdType string // command type + Verbose bool // print verbose log + Cwd string // execution path + OutDir string // output path + HandlerDir string // handler path + ModelDir string // model path + RouterDir string // router path + ClientDir string // client path + BaseDomain string // request domain + ForceClientDir string // client dir (not use namespace as a subpath) + + IdlType string // idl type + IdlPaths []string // master idl path + RawOptPkg []string // user-specified package import path + OptPkgMap map[string]string + Includes []string + PkgPrefix string + + Gopath string // $GOPATH + Gosrc string // $GOPATH/src + Gomod string + Gopkg string // $GOPATH/src/{{gopkg}} + ServiceName string // service name + Use string + NeedGoMod bool + + JSONEnumStr bool + UnsetOmitempty bool + ProtobufCamelJSONTag bool + ProtocOptions []string // options to pass through to protoc + ThriftOptions []string // options to pass through to thriftgo for go flag + ProtobufPlugins []string + ThriftPlugins []string + SnakeName bool + RmTags []string + Excludes []string + NoRecurse bool + HandlerByMethod bool + ForceNew bool + SnakeStyleMiddleware bool + EnableExtends bool + + CustomizeLayout string + CustomizeLayoutData string + CustomizePackage string + ModelBackend string +} + +func NewArgument() *Argument { + return &Argument{ + OptPkgMap: make(map[string]string), + Includes: make([]string, 0, 4), + Excludes: make([]string, 0, 4), + ProtocOptions: make([]string, 0, 4), + ThriftOptions: make([]string, 0, 4), + } +} + +// Parse initializes a new argument based on its own information +func (arg *Argument) Parse(c *cli.Context, cmd string) (*Argument, error) { + // v2 cli cannot put the StringSlice flag to struct, so we need to parse it here + arg.parseStringSlice(c) + args := arg.Fork() + args.CmdType = cmd + + err := args.checkPath() + if err != nil { + return nil, err + } + + err = args.checkIDL() + if err != nil { + return nil, err + } + + err = args.checkPackage() + if err != nil { + return nil, err + } + + return args, nil +} + +func (arg *Argument) parseStringSlice(c *cli.Context) { + arg.IdlPaths = c.StringSlice("idl") + arg.Includes = c.StringSlice("proto_path") + arg.Excludes = c.StringSlice("exclude_file") + arg.RawOptPkg = c.StringSlice("option_package") + arg.ThriftOptions = c.StringSlice("thriftgo") + arg.ProtocOptions = c.StringSlice("protoc") + arg.ThriftPlugins = c.StringSlice("thrift-plugins") + arg.ProtobufPlugins = c.StringSlice("protoc-plugins") + arg.RmTags = c.StringSlice("rm_tag") +} + +func (arg *Argument) UpdateByManifest(m *meta.Manifest) { + if arg.HandlerDir == "" && m.HandlerDir != "" { + logs.Infof("use \"handler_dir\" in \".hz\" as the handler generated dir\n") + arg.HandlerDir = m.HandlerDir + } + if arg.ModelDir == "" && m.ModelDir != "" { + logs.Infof("use \"model_dir\" in \".hz\" as the model generated dir\n") + arg.ModelDir = m.ModelDir + } + if len(m.RouterDir) != 0 { + logs.Infof("use \"router_dir\" in \".hz\" as the router generated dir\n") + arg.RouterDir = m.RouterDir + } +} + +// checkPath sets the project path and verifies that the model、handler、router and client path is compliant +func (arg *Argument) checkPath() error { + dir, err := os.Getwd() + if err != nil { + return fmt.Errorf("get current path failed: %s", err) + } + arg.Cwd = dir + if arg.OutDir == "" { + arg.OutDir = dir + } + if !filepath.IsAbs(arg.OutDir) { + ap := filepath.Join(arg.Cwd, arg.OutDir) + arg.OutDir = ap + } + if arg.ModelDir != "" && filepath.IsAbs(arg.ModelDir) { + return fmt.Errorf("model path %s must be relative to out_dir", arg.ModelDir) + } + if arg.HandlerDir != "" && filepath.IsAbs(arg.HandlerDir) { + return fmt.Errorf("handler path %s must be relative to out_dir", arg.HandlerDir) + } + if arg.RouterDir != "" && filepath.IsAbs(arg.RouterDir) { + return fmt.Errorf("router path %s must be relative to out_dir", arg.RouterDir) + } + if arg.ClientDir != "" && filepath.IsAbs(arg.ClientDir) { + return fmt.Errorf("router path %s must be relative to out_dir", arg.ClientDir) + } + return nil +} + +// checkIDL check if the idl path exists, set and check the idl type +func (arg *Argument) checkIDL() error { + for i, path := range arg.IdlPaths { + abPath, err := filepath.Abs(path) + if err != nil { + return fmt.Errorf("idl path %s is not absolute", path) + } + ext := filepath.Ext(abPath) + if ext == "" || ext[0] != '.' { + return fmt.Errorf("idl path %s is not a valid file", path) + } + ext = ext[1:] + switch ext { + case meta.IdlThrift: + arg.IdlType = meta.IdlThrift + case meta.IdlProto: + arg.IdlType = meta.IdlProto + default: + return fmt.Errorf("IDL type %s is not supported", ext) + } + arg.IdlPaths[i] = abPath + } + return nil +} + +func (arg *Argument) IsUpdate() bool { + return arg.CmdType == meta.CmdUpdate +} + +func (arg *Argument) IsNew() bool { + return arg.CmdType == meta.CmdNew +} + +// checkPackage check and set the gopath、 module and package name +func (arg *Argument) checkPackage() error { + gopath, err := util.GetGOPATH() + if err != nil { + return fmt.Errorf("get gopath failed: %s", err) + } + if gopath == "" { + return fmt.Errorf("GOPATH is not set") + } + + arg.Gopath = gopath + arg.Gosrc = filepath.Join(gopath, "src") + + // Generate the project under gopath, use the relative path as the package name + if strings.HasPrefix(arg.Cwd, arg.Gosrc) { + if gopkg, err := filepath.Rel(arg.Gosrc, arg.Cwd); err != nil { + return fmt.Errorf("get relative path to GOPATH/src failed: %s", err) + } else { + arg.Gopkg = gopkg + } + } + if len(arg.Gomod) == 0 { // not specified "go module" + // search go.mod recursively + module, path, ok := util.SearchGoMod(arg.Cwd, true) + if ok { // find go.mod in upper level, use it as project module, don't generate go.mod + rel, err := filepath.Rel(path, arg.Cwd) + if err != nil { + return fmt.Errorf("can not get relative path, err :%v", err) + } + arg.Gomod = filepath.Join(module, rel) + logs.Debugf("find module '%s' from '%s/go.mod', so use it as module name", module, path) + } + if len(arg.Gomod) == 0 { // don't find go.mod in upper level, use relative path as module name, generate go.mod + logs.Debugf("use gopath's relative path '%s' as the module name", arg.Gopkg) + // gopkg will be "" under non-gopath + arg.Gomod = arg.Gopkg + arg.NeedGoMod = true + } + arg.Gomod = util.PathToImport(arg.Gomod, "") + } else { // specified "go module" + // search go.mod in current path + module, path, ok := util.SearchGoMod(arg.Cwd, false) + if ok { // go.mod exists in current path, check module name, don't generate go.mod + if module != arg.Gomod { + return fmt.Errorf("module name given by the '-module/mod' option ('%s') is not consist with the name defined in go.mod ('%s' from %s), try to remove '-module/mod' option in your command\n", arg.Gomod, module, path) + } + } else { // go.mod don't exist in current path, generate go.mod + arg.NeedGoMod = true + } + } + + if len(arg.Gomod) == 0 { + return fmt.Errorf("can not get go module, please specify a module name with the '-module/mod' flag") + } + + if len(arg.RawOptPkg) > 0 { + arg.OptPkgMap = make(map[string]string, len(arg.RawOptPkg)) + for _, op := range arg.RawOptPkg { + ps := strings.SplitN(op, "=", 2) + if len(ps) != 2 { + return fmt.Errorf("invalid option package: %s", op) + } + arg.OptPkgMap[ps[0]] = ps[1] + } + arg.RawOptPkg = nil + } + return nil +} + +func (arg *Argument) Pack() ([]string, error) { + data, err := util.PackArgs(arg) + if err != nil { + return nil, fmt.Errorf("pack argument failed: %s", err) + } + return data, nil +} + +func (arg *Argument) Unpack(data []string) error { + err := util.UnpackArgs(data, arg) + if err != nil { + return fmt.Errorf("unpack argument failed: %s", err) + } + return nil +} + +// Fork can copy its own parameters to a new argument +func (arg *Argument) Fork() *Argument { + args := NewArgument() + *args = *arg + util.CopyString2StringMap(arg.OptPkgMap, args.OptPkgMap) + util.CopyStringSlice(&arg.Includes, &args.Includes) + util.CopyStringSlice(&arg.Excludes, &args.Excludes) + util.CopyStringSlice(&arg.ProtocOptions, &args.ProtocOptions) + util.CopyStringSlice(&arg.ThriftOptions, &args.ThriftOptions) + return args +} + +func (arg *Argument) GetGoPackage() (string, error) { + if arg.Gomod != "" { + return arg.Gomod, nil + } else if arg.Gopkg != "" { + return arg.Gopkg, nil + } + return "", fmt.Errorf("project package name is not set") +} + +func IdlTypeToCompiler(idlType string) (string, error) { + switch idlType { + case meta.IdlProto: + return meta.TpCompilerProto, nil + case meta.IdlThrift: + return meta.TpCompilerThrift, nil + default: + return "", fmt.Errorf("IDL type %s is not supported", idlType) + } +} + +func (arg *Argument) ModelPackagePrefix() (string, error) { + ret := arg.Gomod + if arg.ModelDir == "" { + path, err := util.RelativePath(meta.ModelDir) + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + if err != nil { + return "", err + } + ret += path + } else { + path, err := util.RelativePath(arg.ModelDir) + if err != nil { + return "", err + } + ret += "/" + path + } + return strings.ReplaceAll(ret, string(filepath.Separator), "/"), nil +} + +func (arg *Argument) ModelOutDir() string { + ret := arg.OutDir + if arg.ModelDir == "" { + ret = filepath.Join(ret, meta.ModelDir) + } else { + ret = filepath.Join(ret, arg.ModelDir) + } + return ret +} + +func (arg *Argument) GetHandlerDir() (string, error) { + if arg.HandlerDir == "" { + return util.RelativePath(meta.HandlerDir) + } + return util.RelativePath(arg.HandlerDir) +} + +func (arg *Argument) GetModelDir() (string, error) { + if arg.ModelDir == "" { + return util.RelativePath(meta.ModelDir) + } + return util.RelativePath(arg.ModelDir) +} + +func (arg *Argument) GetRouterDir() (string, error) { + if arg.RouterDir == "" { + return util.RelativePath(meta.RouterDir) + } + return util.RelativePath(arg.RouterDir) +} + +func (arg *Argument) GetClientDir() (string, error) { + if arg.ClientDir == "" { + return "", nil + } + return util.RelativePath(arg.ClientDir) +} + +func (arg *Argument) InitManifest(m *meta.Manifest) { + m.Version = meta.Version + m.HandlerDir = arg.HandlerDir + m.ModelDir = arg.ModelDir + m.RouterDir = arg.RouterDir +} + +func (arg *Argument) UpdateManifest(m *meta.Manifest) { + m.Version = meta.Version + if arg.HandlerDir != m.HandlerDir { + m.HandlerDir = arg.HandlerDir + } + if arg.HandlerDir != m.ModelDir { + m.ModelDir = arg.ModelDir + } + // "router_dir" must not be defined by "update" command +} diff --git a/config/cmd.go b/config/cmd.go new file mode 100644 index 0000000..fc83c57 --- /dev/null +++ b/config/cmd.go @@ -0,0 +1,187 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package config + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" + + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/hertz/cmd/hz/util/logs" +) + +func lookupTool(idlType string) (string, error) { + tool := meta.TpCompilerThrift + if idlType == meta.IdlProto { + tool = meta.TpCompilerProto + } + + path, err := exec.LookPath(tool) + logs.Debugf("[DEBUG]path:%v", path) + if err != nil { + goPath, err := util.GetGOPATH() + if err != nil { + return "", fmt.Errorf("get 'GOPATH' failed for find %s : %v", tool, path) + } + path = filepath.Join(goPath, "bin", tool) + } + + isExist, err := util.PathExist(path) + if err != nil { + return "", fmt.Errorf("check '%s' path error: %v", path, err) + } + + if !isExist { + if tool == meta.TpCompilerThrift { + // If thriftgo does not exist, the latest version will be installed automatically. + err := util.InstallAndCheckThriftgo() + if err != nil { + return "", fmt.Errorf("can't install '%s' automatically, please install it manually for https://github.com/cloudwego/thriftgo, err : %v", tool, err) + } + } else { + // todo: protoc automatic installation + return "", fmt.Errorf("%s is not installed, please install it first", tool) + } + } + + if tool == meta.TpCompilerThrift { + // If thriftgo exists, the version is detected; if the version is lower than v0.2.0 then the latest version of thriftgo is automatically installed. + err := util.CheckAndUpdateThriftgo() + if err != nil { + return "", fmt.Errorf("update thriftgo version failed, please install it manually for https://github.com/cloudwego/thriftgo, err: %v", err) + } + } + + return path, nil +} + +// link removes the previous symbol link and rebuilds a new one. +func link(src, dst string) error { + err := syscall.Unlink(dst) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("unlink %q: %s", dst, err) + } + err = os.Symlink(src, dst) + if err != nil { + return fmt.Errorf("symlink %q: %s", dst, err) + } + return nil +} + +func BuildPluginCmd(args *Argument) (*exec.Cmd, error) { + exe, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("failed to detect current executable, err: %v", err) + } + + argPacks, err := args.Pack() + if err != nil { + return nil, err + } + kas := strings.Join(argPacks, ",") + + path, err := lookupTool(args.IdlType) + if err != nil { + return nil, err + } + cmd := &exec.Cmd{ + Path: path, + } + + if args.IdlType == meta.IdlThrift { + // thriftgo + os.Setenv(meta.EnvPluginMode, meta.ThriftPluginName) + cmd.Args = append(cmd.Args, meta.TpCompilerThrift) + for _, inc := range args.Includes { + cmd.Args = append(cmd.Args, "-i", inc) + } + + if args.Verbose { + cmd.Args = append(cmd.Args, "-v") + } + thriftOpt, err := args.GetThriftgoOptions() + if err != nil { + return nil, err + } + cmd.Args = append(cmd.Args, + "-o", args.ModelOutDir(), + "-g", thriftOpt, + "-p", "hertz="+exe+":"+kas, + ) + for _, p := range args.ThriftPlugins { + cmd.Args = append(cmd.Args, "-p", p) + } + if !args.NoRecurse { + cmd.Args = append(cmd.Args, "-r") + } + } else { + // protoc + os.Setenv(meta.EnvPluginMode, meta.ProtocPluginName) + cmd.Args = append(cmd.Args, meta.TpCompilerProto) + for _, inc := range args.Includes { + cmd.Args = append(cmd.Args, "-I", inc) + } + for _, inc := range args.IdlPaths { + cmd.Args = append(cmd.Args, "-I", filepath.Dir(inc)) + } + cmd.Args = append(cmd.Args, + "--plugin=protoc-gen-hertz="+exe, + "--hertz_out="+args.OutDir, + "--hertz_opt="+kas, + ) + for _, p := range args.ProtobufPlugins { + pluginParams := strings.Split(p, ":") + if len(pluginParams) != 3 { + logs.Warnf("Failed to get the correct protoc plugin parameters for %. "+ + "Please specify the protoc plugin in the form of \"plugin_name:options:out_dir\"", p) + os.Exit(1) + } + // pluginParams[0] -> plugin name, pluginParams[1] -> plugin options, pluginParams[2] -> out_dir + cmd.Args = append(cmd.Args, + fmt.Sprintf("--%s_out=%s", pluginParams[0], pluginParams[2]), + fmt.Sprintf("--%s_opt=%s", pluginParams[0], pluginParams[1]), + ) + } + for _, kv := range args.ProtocOptions { + cmd.Args = append(cmd.Args, "--"+kv) + } + } + + cmd.Args = append(cmd.Args, args.IdlPaths...) + logs.Infof(strings.Join(cmd.Args, " ")) + logs.Flush() + return cmd, nil +} + +func (arg *Argument) GetThriftgoOptions() (string, error) { + defaultOpt := "reserve_comments,gen_json_tag=false," + prefix, err := arg.ModelPackagePrefix() + if err != nil { + return "", err + } + arg.ThriftOptions = append(arg.ThriftOptions, "package_prefix="+prefix) + if arg.JSONEnumStr { + arg.ThriftOptions = append(arg.ThriftOptions, "json_enum_as_text") + } + gas := "go:" + defaultOpt + strings.Join(arg.ThriftOptions, ",") + return gas, nil +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..6eff898 --- /dev/null +++ b/doc.go @@ -0,0 +1,20 @@ +// Copyright 2022 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package "github.com/cloudwego/hertz/cmd/hz" contains packages for building the hz command line tool. +// APIs exported by packages under this directory do not promise any backward +// compatibility, so please do not rely on them. + +package main diff --git a/generator/client.go b/generator/client.go new file mode 100644 index 0000000..ff32c4b --- /dev/null +++ b/generator/client.go @@ -0,0 +1,100 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import ( + "path/filepath" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/util" +) + +type ClientMethod struct { + *HttpMethod + BodyParamsCode string + QueryParamsCode string + PathParamsCode string + HeaderParamsCode string + FormValueCode string + FormFileCode string +} + +type ClientFile struct { + FilePath string + PackageName string + ServiceName string + BaseDomain string + Imports map[string]*model.Model + ClientMethods []*ClientMethod +} + +func (pkgGen *HttpPackageGenerator) genClient(pkg *HttpPackage, clientDir string) error { + for _, s := range pkg.Services { + cliDir := util.SubDir(clientDir, util.ToSnakeCase(s.Name)) + if len(pkgGen.ForceClientDir) != 0 { + cliDir = pkgGen.ForceClientDir + } + hertzClientPath := filepath.Join(cliDir, hertzClientTplName) + isExist, err := util.PathExist(hertzClientPath) + if err != nil { + return err + } + baseDomain := s.BaseDomain + if len(pkgGen.BaseDomain) != 0 { + baseDomain = pkgGen.BaseDomain + } + client := ClientFile{ + FilePath: filepath.Join(cliDir, util.ToSnakeCase(s.Name)+".go"), + PackageName: util.ToSnakeCase(filepath.Base(cliDir)), + ServiceName: util.ToCamelCase(s.Name), + ClientMethods: s.ClientMethods, + BaseDomain: baseDomain, + } + if !isExist { + err := pkgGen.TemplateGenerator.Generate(client, hertzClientTplName, hertzClientPath, false) + if err != nil { + return err + } + } + client.Imports = make(map[string]*model.Model, len(client.ClientMethods)) + for _, m := range client.ClientMethods { + // Iterate over the request and return parameters of the method to get import path. + for key, mm := range m.Models { + if v, ok := client.Imports[mm.PackageName]; ok && v.Package != mm.Package { + client.Imports[key] = mm + continue + } + client.Imports[mm.PackageName] = mm + } + } + if len(pkgGen.UseDir) != 0 { + oldModelDir := filepath.Clean(filepath.Join(pkgGen.ProjPackage, pkgGen.ModelDir)) + newModelDir := filepath.Clean(pkgGen.UseDir) + for _, m := range client.ClientMethods { + for _, mm := range m.Models { + mm.Package = strings.Replace(mm.Package, oldModelDir, newModelDir, 1) + } + } + } + err = pkgGen.TemplateGenerator.Generate(client, idlClientName, client.FilePath, false) + if err != nil { + return err + } + } + return nil +} diff --git a/generator/custom_files.go b/generator/custom_files.go new file mode 100644 index 0000000..1315948 --- /dev/null +++ b/generator/custom_files.go @@ -0,0 +1,657 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import ( + "bytes" + "fmt" + "io/ioutil" + "path/filepath" + "strings" + "text/template" + + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/hertz/cmd/hz/util/logs" +) + +type FilePathRenderInfo struct { + MasterIDLName string // master IDL name + GenPackage string // master IDL generate code package + HandlerDir string // handler generate dir + ModelDir string // model generate dir + RouterDir string // router generate dir + ProjectDir string // projectDir + GoModule string // go module + ServiceName string // service name, changed as services are traversed + MethodName string // method name, changed as methods are traversed + HandlerGenPath string // "api.gen_path" value +} + +type IDLPackageRenderInfo struct { + FilePathRenderInfo + ServiceInfos *HttpPackage +} + +type CustomizedFileForMethod struct { + *HttpMethod + FilePath string + FilePackage string + ServiceInfo *Service // service info for this method + IDLPackageInfo *IDLPackageRenderInfo // IDL info for this service +} + +type CustomizedFileForService struct { + *Service + FilePath string + FilePackage string + IDLPackageInfo *IDLPackageRenderInfo // IDL info for this service +} + +type CustomizedFileForIDL struct { + *IDLPackageRenderInfo + FilePath string + FilePackage string +} + +// todo: 1. how to import other file, if the other file name is a template + +// genCustomizedFile generate customized file template +func (pkgGen *HttpPackageGenerator) genCustomizedFile(pkg *HttpPackage) error { + filePathRenderInfo := FilePathRenderInfo{ + MasterIDLName: pkg.IdlName, + GenPackage: pkg.Package, + HandlerDir: pkgGen.HandlerDir, + ModelDir: pkgGen.ModelDir, + RouterDir: pkgGen.RouterDir, + ProjectDir: pkgGen.OutputDir, + GoModule: pkgGen.ProjPackage, + // methodName & serviceName will change as traverse + } + + idlPackageRenderInfo := IDLPackageRenderInfo{ + FilePathRenderInfo: filePathRenderInfo, + ServiceInfos: pkg, + } + + for _, tplInfo := range pkgGen.tplsInfo { + // the default template has been automatically generated by the tool, so skip + if tplInfo.Default { + continue + } + + // loop generate file + if tplInfo.LoopService || tplInfo.LoopMethod { + loopMethod := tplInfo.LoopMethod + loopService := tplInfo.LoopService + if loopService && !loopMethod { // only loop service + for _, service := range idlPackageRenderInfo.ServiceInfos.Services { + filePathRenderInfo.ServiceName = service.Name + err := pkgGen.genLoopService(tplInfo, filePathRenderInfo, service, &idlPackageRenderInfo) + if err != nil { + return err + } + } + } else { // loop service & method, because if loop method, the service must be looped + for _, service := range idlPackageRenderInfo.ServiceInfos.Services { + for _, method := range service.Methods { + filePathRenderInfo.ServiceName = service.Name + filePathRenderInfo.MethodName = method.Name + filePathRenderInfo.HandlerGenPath = method.OutputDir + err := pkgGen.genLoopMethod(tplInfo, filePathRenderInfo, method, service, &idlPackageRenderInfo) + if err != nil { + return err + } + } + } + } + } else { // generate customized file single + err := pkgGen.genSingleCustomizedFile(tplInfo, filePathRenderInfo, idlPackageRenderInfo) + if err != nil { + return err + } + } + } + return nil +} + +// renderFilePath used to render file path template to get real file path +func renderFilePath(tplInfo *Template, filePathRenderInfo FilePathRenderInfo) (string, error) { + tpl, err := template.New(tplInfo.Path).Funcs(funcMap).Parse(tplInfo.Path) + if err != nil { + return "", fmt.Errorf("parse file path template(%s) failed, err: %v", tplInfo.Path, err) + } + filePath := bytes.NewBuffer(nil) + err = tpl.Execute(filePath, filePathRenderInfo) + if err != nil { + return "", fmt.Errorf("render file path template(%s) failed, err: %v", tplInfo.Path, err) + } + + return filePath.String(), nil +} + +func renderInsertKey(tplInfo *Template, data interface{}) (string, error) { + tpl, err := template.New(tplInfo.UpdateBehavior.InsertKey).Funcs(funcMap).Parse(tplInfo.UpdateBehavior.InsertKey) + if err != nil { + return "", fmt.Errorf("parse insert key template(%s) failed, err: %v", tplInfo.UpdateBehavior.InsertKey, err) + } + insertKey := bytes.NewBuffer(nil) + err = tpl.Execute(insertKey, data) + if err != nil { + return "", fmt.Errorf("render insert key template(%s) failed, err: %v", tplInfo.UpdateBehavior.InsertKey, err) + } + + return insertKey.String(), nil +} + +// renderImportTpl will render import template +// it will return the []string, like blow: +// ["import", alias "import", import] +// other format will be error +func renderImportTpl(tplInfo *Template, data interface{}) ([]string, error) { + var importList []string + for _, impt := range tplInfo.UpdateBehavior.ImportTpl { + tpl, err := template.New(impt).Funcs(funcMap).Parse(impt) + if err != nil { + return nil, fmt.Errorf("parse import template(%s) failed, err: %v", impt, err) + } + imptContent := bytes.NewBuffer(nil) + err = tpl.Execute(imptContent, data) + if err != nil { + return nil, fmt.Errorf("render import template(%s) failed, err: %v", impt, err) + } + importList = append(importList, imptContent.String()) + } + var ret []string + for _, impts := range importList { + // 'import render result' may have multiple imports + if strings.Contains(impts, "\n") { + for _, impt := range strings.Split(impts, "\n") { + ret = append(ret, impt) + } + } else { + ret = append(ret, impts) + } + } + + return ret, nil +} + +// renderAppendContent used to render append content for 'update' command +func renderAppendContent(tplInfo *Template, renderInfo interface{}) (string, error) { + tpl, err := template.New(tplInfo.Path).Funcs(funcMap).Parse(tplInfo.UpdateBehavior.AppendTpl) + if err != nil { + return "", fmt.Errorf("parse append content template(%s) failed, err: %v", tplInfo.Path, err) + } + appendContent := bytes.NewBuffer(nil) + err = tpl.Execute(appendContent, renderInfo) + if err != nil { + return "", fmt.Errorf("render append content template(%s) failed, err: %v", tplInfo.Path, err) + } + + return appendContent.String(), nil +} + +// appendUpdateFile used to append content to file for 'update' command +func appendUpdateFile(tplInfo *Template, renderInfo interface{}, fileContent []byte) ([]byte, error) { + // render insert content + appendContent, err := renderAppendContent(tplInfo, renderInfo) + if err != nil { + return []byte(""), err + } + buf := bytes.NewBuffer(nil) + _, err = buf.Write(fileContent) + if err != nil { + return []byte(""), fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) + } + // "\r\n" && "\n" has the same suffix + if !bytes.HasSuffix(buf.Bytes(), []byte("\n")) { + _, err = buf.WriteString("\n") + if err != nil { + return []byte(""), fmt.Errorf("write file(%s) line break failed, err: %v", tplInfo.Path, err) + } + } + _, err = buf.WriteString(appendContent) + if err != nil { + return []byte(""), fmt.Errorf("append file(%s) failed, err: %v", tplInfo.Path, err) + } + + return buf.Bytes(), nil +} + +func getInsertImportContent(tplInfo *Template, renderInfo interface{}, fileContent []byte) ([][2]string, error) { + importContent, err := renderImportTpl(tplInfo, renderInfo) + if err != nil { + return nil, err + } + var imptSlice [][2]string + for _, impt := range importContent { + // import has to format + // 1. alias "import" + // 2. "import" + // 3. import (can not contain '"') + impt = strings.TrimSpace(impt) + if !strings.Contains(impt, "\"") { // 3. import (can not contain '"') + if bytes.Contains(fileContent, []byte(impt)) { + continue + } + imptSlice = append(imptSlice, [2]string{"", impt}) + } else { + if !strings.HasSuffix(impt, "\"") { + return nil, fmt.Errorf("import can not has suffix \"\"\", for file: %s", tplInfo.Path) + } + if strings.HasPrefix(impt, "\"") { // 2. "import" + if bytes.Contains(fileContent, []byte(impt[1:len(impt)-1])) { + continue + } + imptSlice = append(imptSlice, [2]string{"", impt[1 : len(impt)-1]}) + } else { // 3. alias "import" + idx := strings.Index(impt, "\n") + if idx == -1 { + return nil, fmt.Errorf("error import format for file: %s", tplInfo.Path) + } + if bytes.Contains(fileContent, []byte(impt[idx+1:len(impt)-1])) { + continue + } + imptSlice = append(imptSlice, [2]string{impt[:idx], impt[idx+1 : len(impt)-1]}) + } + } + } + + return imptSlice, nil +} + +// genLoopService used to generate files by 'service' +func (pkgGen *HttpPackageGenerator) genLoopService(tplInfo *Template, filePathRenderInfo FilePathRenderInfo, service *Service, idlPackageRenderInfo *IDLPackageRenderInfo) error { + filePath, err := renderFilePath(tplInfo, filePathRenderInfo) + if err != nil { + return err + } + // determine if a custom file exists + exist, err := util.PathExist(filePath) + if err != nil { + return fmt.Errorf("judge file(%s) exists failed, err: %v", filePath, err) + } + if !exist { // create file + data := CustomizedFileForService{ + Service: service, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + IDLPackageInfo: idlPackageRenderInfo, + } + err = pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) + if err != nil { + return err + } + } else { + switch tplInfo.UpdateBehavior.Type { + case Skip: + // do nothing + logs.Infof("do not update file '%s', because the update behavior is 'Unchanged'", filePath) + case Cover: + // re-generate + logs.Infof("re-generate file '%s', because the update behavior is 'Regenerate'", filePath) + data := CustomizedFileForService{ + Service: service, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + IDLPackageInfo: idlPackageRenderInfo, + } + err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) + if err != nil { + return err + } + case Append: // todo: append logic need to be optimized for method + fileContent, err := ioutil.ReadFile(filePath) + if err != nil { + return err + } + var appendContent []byte + // get insert content + if tplInfo.UpdateBehavior.AppendKey == "method" { + for _, method := range service.Methods { + data := CustomizedFileForMethod{ + HttpMethod: method, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + ServiceInfo: service, + IDLPackageInfo: idlPackageRenderInfo, + } + insertKey, err := renderInsertKey(tplInfo, data) + if err != nil { + return err + } + if bytes.Contains(fileContent, []byte(insertKey)) { + continue + } + imptSlice, err := getInsertImportContent(tplInfo, data, fileContent) + if err != nil { + return err + } + // insert new import to the fileContent + for _, impt := range imptSlice { + if bytes.Contains(fileContent, []byte(impt[1])) { + continue + } + fileContent, err = util.AddImportForContent(fileContent, impt[0], impt[1]) + // insert import error do not influence the generated file + if err != nil { + logs.Warnf("can not add import(%s) for file(%s), err: %v\n", impt[1], filePath, err) + } + } + appendContent, err = appendUpdateFile(tplInfo, data, appendContent) + if err != nil { + return err + } + } + if len(tplInfo.UpdateBehavior.AppendLocation) == 0 { // default, append to end of file + buf := bytes.NewBuffer(nil) + _, err = buf.Write(fileContent) + if err != nil { + return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) + } + _, err = buf.Write(appendContent) + if err != nil { + return fmt.Errorf("append file(%s) failed, err: %v", tplInfo.Path, err) + } + logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'method'", filePath) + pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) + } else { // 'append location', append new content after 'append location' + part := bytes.Split(fileContent, []byte(tplInfo.UpdateBehavior.AppendLocation)) + if len(part) == 0 { + return fmt.Errorf("can not find append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) + } + if len(part) != 2 { + return fmt.Errorf("do not support multiple append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) + } + buf := bytes.NewBuffer(nil) + err = writeBytes(buf, part[0], []byte(tplInfo.UpdateBehavior.AppendLocation), appendContent, part[1]) + if err != nil { + return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) + } + logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'method'", filePath) + pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) + } + } else { + logs.Warnf("Loop 'service' field for '%s' only append content by appendKey for 'method', so cannot append content", filePath) + } + default: + // do nothing + logs.Warnf("unknown update behavior, do nothing") + } + } + return nil +} + +// genLoopMethod used to generate files by 'method' +func (pkgGen *HttpPackageGenerator) genLoopMethod(tplInfo *Template, filePathRenderInfo FilePathRenderInfo, method *HttpMethod, service *Service, idlPackageRenderInfo *IDLPackageRenderInfo) error { + filePath, err := renderFilePath(tplInfo, filePathRenderInfo) + if err != nil { + return err + } + // determine if a custom file exists + exist, err := util.PathExist(filePath) + if err != nil { + return fmt.Errorf("judge file(%s) exists failed, err: %v", filePath, err) + } + + if !exist { // create file + data := CustomizedFileForMethod{ + HttpMethod: method, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + ServiceInfo: service, + IDLPackageInfo: idlPackageRenderInfo, + } + err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) + if err != nil { + return err + } + } else { + switch tplInfo.UpdateBehavior.Type { + case Skip: + // do nothing + logs.Infof("do not update file '%s', because the update behavior is 'Unchanged'", filePath) + case Cover: + // re-generate + logs.Infof("re-generate file '%s', because the update behavior is 'Regenerate'", filePath) + data := CustomizedFileForMethod{ + HttpMethod: method, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + ServiceInfo: service, + IDLPackageInfo: idlPackageRenderInfo, + } + err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) + if err != nil { + return err + } + case Append: + // for loop method, no need to append something; so do nothing + logs.Warnf("do not append content for file '%s', because the update behavior is 'Append' and loop 'method' have no need to append content", filePath) + default: + // do nothing + logs.Warnf("unknown update behavior, do nothing") + } + } + + return nil +} + +// genSingleCustomizedFile used to generate file by 'master IDL' +func (pkgGen *HttpPackageGenerator) genSingleCustomizedFile(tplInfo *Template, filePathRenderInfo FilePathRenderInfo, idlPackageRenderInfo IDLPackageRenderInfo) error { + // generate file single + filePath, err := renderFilePath(tplInfo, filePathRenderInfo) + if err != nil { + return err + } + // determine if a custom file exists + exist, err := util.PathExist(filePath) + if err != nil { + return fmt.Errorf("judge file(%s) exists failed, err: %v", filePath, err) + } + + if !exist { // create file + data := CustomizedFileForIDL{ + IDLPackageRenderInfo: &idlPackageRenderInfo, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + } + err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) + if err != nil { + return err + } + } else { + switch tplInfo.UpdateBehavior.Type { + case Skip: + // do nothing + logs.Infof("do not update file '%s', because the update behavior is 'Unchanged'", filePath) + case Cover: + // re-generate + logs.Infof("re-generate file '%s', because the update behavior is 'Regenerate'", filePath) + data := CustomizedFileForIDL{ + IDLPackageRenderInfo: &idlPackageRenderInfo, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + } + err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) + if err != nil { + return err + } + case Append: // todo: append logic need to be optimized for single file + fileContent, err := ioutil.ReadFile(filePath) + if err != nil { + return err + } + if tplInfo.UpdateBehavior.AppendKey == "method" { + var appendContent []byte + for _, service := range idlPackageRenderInfo.ServiceInfos.Services { + for _, method := range service.Methods { + data := CustomizedFileForMethod{ + HttpMethod: method, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + ServiceInfo: service, + IDLPackageInfo: &idlPackageRenderInfo, + } + insertKey, err := renderInsertKey(tplInfo, data) + if err != nil { + return err + } + if bytes.Contains(fileContent, []byte(insertKey)) { + continue + } + imptSlice, err := getInsertImportContent(tplInfo, data, fileContent) + if err != nil { + return err + } + for _, impt := range imptSlice { + if bytes.Contains(fileContent, []byte(impt[1])) { + continue + } + fileContent, err = util.AddImportForContent(fileContent, impt[0], impt[1]) + if err != nil { + logs.Warnf("can not add import(%s) for file(%s)\n", impt[1], filePath) + } + } + + appendContent, err = appendUpdateFile(tplInfo, data, appendContent) + if err != nil { + return err + } + } + } + if len(tplInfo.UpdateBehavior.AppendLocation) == 0 { // default, append to end of file + buf := bytes.NewBuffer(nil) + _, err = buf.Write(fileContent) + if err != nil { + return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) + } + _, err = buf.Write(appendContent) + if err != nil { + return fmt.Errorf("append file(%s) failed, err: %v", tplInfo.Path, err) + } + logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'method'", filePath) + pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) + } else { // 'append location', append new content after 'append location' + part := bytes.Split(fileContent, []byte(tplInfo.UpdateBehavior.AppendLocation)) + if len(part) == 0 { + return fmt.Errorf("can not find append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) + } + if len(part) != 2 { + return fmt.Errorf("do not support multiple append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) + } + buf := bytes.NewBuffer(nil) + err = writeBytes(buf, part[0], []byte(tplInfo.UpdateBehavior.AppendLocation), appendContent, part[1]) + if err != nil { + return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) + } + logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'method'", filePath) + pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) + } + } else if tplInfo.UpdateBehavior.AppendKey == "service" { + var appendContent []byte + for _, service := range idlPackageRenderInfo.ServiceInfos.Services { + data := CustomizedFileForService{ + Service: service, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + IDLPackageInfo: &idlPackageRenderInfo, + } + insertKey, err := renderInsertKey(tplInfo, data) + if err != nil { + return err + } + if bytes.Contains(fileContent, []byte(insertKey)) { + continue + } + imptSlice, err := getInsertImportContent(tplInfo, data, fileContent) + if err != nil { + return err + } + for _, impt := range imptSlice { + if bytes.Contains(fileContent, []byte(impt[1])) { + continue + } + fileContent, err = util.AddImportForContent(fileContent, impt[0], impt[1]) + if err != nil { + logs.Warnf("can not add import(%s) for file(%s)\n", impt[1], filePath) + } + } + appendContent, err = appendUpdateFile(tplInfo, data, appendContent) + if err != nil { + return err + } + } + if len(tplInfo.UpdateBehavior.AppendLocation) == 0 { // default, append to end of file + buf := bytes.NewBuffer(nil) + _, err = buf.Write(fileContent) + if err != nil { + return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) + } + _, err = buf.Write(appendContent) + if err != nil { + return fmt.Errorf("append file(%s) failed, err: %v", tplInfo.Path, err) + } + logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'service'", filePath) + pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) + } else { // 'append location', append new content after 'append location' + part := bytes.Split(fileContent, []byte(tplInfo.UpdateBehavior.AppendLocation)) + if len(part) == 0 { + return fmt.Errorf("can not find append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) + } + if len(part) != 2 { + return fmt.Errorf("do not support multiple append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) + } + buf := bytes.NewBuffer(nil) + err = writeBytes(buf, part[0], []byte(tplInfo.UpdateBehavior.AppendLocation), appendContent, part[1]) + if err != nil { + return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) + } + logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'service'", filePath) + pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) + } + } else { // add append content to the file directly + data := CustomizedFileForIDL{ + IDLPackageRenderInfo: &idlPackageRenderInfo, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + } + file, err := appendUpdateFile(tplInfo, data, fileContent) + if err != nil { + return err + } + pkgGen.files = append(pkgGen.files, File{filePath, string(file), false, ""}) + } + default: + // do nothing + logs.Warnf("unknown update behavior, do nothing") + } + } + + return nil +} + +func writeBytes(buf *bytes.Buffer, bytes ...[]byte) error { + for _, b := range bytes { + _, err := buf.Write(b) + if err != nil { + return err + } + } + + return nil +} diff --git a/generator/file.go b/generator/file.go new file mode 100644 index 0000000..6e0133a --- /dev/null +++ b/generator/file.go @@ -0,0 +1,46 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import ( + "fmt" + "go/format" + "path/filepath" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/util" +) + +type File struct { + Path string + Content string + NoRepeat bool + FileTplName string +} + +// Lint is used to statically analyze and format go code +func (file *File) Lint() error { + name := filepath.Base(file.Path) + if strings.HasSuffix(name, ".go") { + out, err := format.Source(util.Str2Bytes(file.Content)) + if err != nil { + return fmt.Errorf("lint file '%s' failed, err: %v", name, err.Error()) + } + file.Content = util.Bytes2Str(out) + } + return nil +} diff --git a/generator/handler.go b/generator/handler.go new file mode 100644 index 0000000..609e1d6 --- /dev/null +++ b/generator/handler.go @@ -0,0 +1,320 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import ( + "bytes" + "fmt" + "io/ioutil" + "path/filepath" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/hertz/cmd/hz/util/logs" +) + +type HttpMethod struct { + Name string + HTTPMethod string + Comment string + RequestTypeName string + RequestTypePackage string + RequestTypeRawName string + ReturnTypeName string + ReturnTypePackage string + ReturnTypeRawName string + Path string + Serializer string + OutputDir string + RefPackage string // handler import dir + RefPackageAlias string // handler import alias + ModelPackage map[string]string + GenHandler bool // Whether to generate one handler, when an idl interface corresponds to multiple http method + // Annotations map[string]string + Models map[string]*model.Model +} + +type Handler struct { + FilePath string + PackageName string + ProjPackage string + Imports map[string]*model.Model + Methods []*HttpMethod +} + +type SingleHandler struct { + *HttpMethod + FilePath string + PackageName string + ProjPackage string +} + +type Client struct { + Handler + ServiceName string +} + +func (pkgGen *HttpPackageGenerator) genHandler(pkg *HttpPackage, handlerDir, handlerPackage string, root *RouterNode) error { + for _, s := range pkg.Services { + var handler Handler + if pkgGen.HandlerByMethod { // generate handler by method + for _, m := range s.Methods { + filePath := filepath.Join(handlerDir, m.OutputDir, util.ToSnakeCase(m.Name)+".go") + handler = Handler{ + FilePath: filePath, + PackageName: util.SplitPackage(filepath.Dir(filePath), ""), + Methods: []*HttpMethod{m}, + ProjPackage: pkgGen.ProjPackage, + } + + if err := pkgGen.processHandler(&handler, root, handlerDir, m.OutputDir, true); err != nil { + return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) + } + + if m.GenHandler { + if err := pkgGen.updateHandler(handler, handlerTplName, handler.FilePath, false); err != nil { + return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) + } + } + } + } else { // generate handler service + tmpHandlerDir := handlerDir + tmpHandlerPackage := handlerPackage + if len(s.ServiceGenDir) != 0 { + tmpHandlerDir = s.ServiceGenDir + tmpHandlerPackage = util.SubPackage(pkgGen.ProjPackage, tmpHandlerDir) + } + handler = Handler{ + FilePath: filepath.Join(tmpHandlerDir, util.ToSnakeCase(s.Name)+".go"), + PackageName: util.SplitPackage(tmpHandlerPackage, ""), + Methods: s.Methods, + ProjPackage: pkgGen.ProjPackage, + } + + for _, m := range s.Methods { + m.RefPackage = tmpHandlerPackage + m.RefPackageAlias = util.BaseName(tmpHandlerPackage, "") + } + + if err := pkgGen.processHandler(&handler, root, "", "", false); err != nil { + return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) + } + + // Avoid generating duplicate handlers when IDL interface corresponds to multiple http methods + methods := handler.Methods + handler.Methods = []*HttpMethod{} + for _, m := range methods { + if m.GenHandler { + handler.Methods = append(handler.Methods, m) + } + } + + if err := pkgGen.updateHandler(handler, handlerTplName, handler.FilePath, false); err != nil { + return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) + } + } + + if len(pkgGen.ClientDir) != 0 { + clientDir := util.SubDir(pkgGen.ClientDir, pkg.Package) + clientPackage := util.SubPackage(pkgGen.ProjPackage, clientDir) + client := Client{} + client.Handler = handler + client.ServiceName = s.Name + client.PackageName = util.SplitPackage(clientPackage, "") + client.FilePath = filepath.Join(clientDir, util.ToSnakeCase(s.Name)+".go") + if err := pkgGen.updateClient(client, clientTplName, client.FilePath, false); err != nil { + return fmt.Errorf("generate client %s failed, err: %v", client.FilePath, err.Error()) + } + } + + } + return nil +} + +func (pkgGen *HttpPackageGenerator) processHandler(handler *Handler, root *RouterNode, handlerDir, projectOutDir string, handlerByMethod bool) error { + singleHandlerPackage := "" + if handlerByMethod { + singleHandlerPackage = util.SubPackage(pkgGen.ProjPackage, filepath.Join(handlerDir, projectOutDir)) + } + handler.Imports = make(map[string]*model.Model, len(handler.Methods)) + for _, m := range handler.Methods { + // Iterate over the request and return parameters of the method to get import path. + for key, mm := range m.Models { + if v, ok := handler.Imports[mm.PackageName]; ok && v.Package != mm.Package { + handler.Imports[key] = mm + continue + } + handler.Imports[mm.PackageName] = mm + } + err := root.Update(m, handler.PackageName, singleHandlerPackage) + if err != nil { + return err + } + } + + if len(pkgGen.UseDir) != 0 { + oldModelDir := filepath.Clean(filepath.Join(pkgGen.ProjPackage, pkgGen.ModelDir)) + newModelDir := filepath.Clean(pkgGen.UseDir) + for _, m := range handler.Methods { + for _, mm := range m.Models { + mm.Package = strings.Replace(mm.Package, oldModelDir, newModelDir, 1) + } + } + } + + handler.Format() + return nil +} + +func (pkgGen *HttpPackageGenerator) updateHandler(handler interface{}, handlerTpl, filePath string, noRepeat bool) error { + if pkgGen.tplsInfo[handlerTpl].Disable { + return nil + } + isExist, err := util.PathExist(filePath) + if err != nil { + return err + } + if !isExist { + return pkgGen.TemplateGenerator.Generate(handler, handlerTpl, filePath, noRepeat) + } + if pkgGen.HandlerByMethod { // method by handler, do not need to insert new content + return nil + } + + file, err := ioutil.ReadFile(filePath) + if err != nil { + return err + } + + // insert new model imports + for alias, model := range handler.(Handler).Imports { + if bytes.Contains(file, []byte(model.Package)) { + continue + } + file, err = util.AddImportForContent(file, alias, model.Package) + if err != nil { + return err + } + } + // insert customized imports + if tplInfo, exist := pkgGen.TemplateGenerator.tplsInfo[handlerTpl]; exist { + if len(tplInfo.UpdateBehavior.ImportTpl) != 0 { + imptSlice, err := getInsertImportContent(tplInfo, handler, file) + if err != nil { + return err + } + for _, impt := range imptSlice { + if bytes.Contains(file, []byte(impt[1])) { + continue + } + file, err = util.AddImportForContent(file, impt[0], impt[1]) + if err != nil { + logs.Warnf("can not add import(%s) for file(%s), err: %v\n", impt[1], filePath, err) + } + } + } + } + + // insert new handler + for _, method := range handler.(Handler).Methods { + if bytes.Contains(file, []byte(fmt.Sprintf("func %s(", method.Name))) { + continue + } + + // Generate additional handlers using templates + handlerSingleTpl := pkgGen.tpls[handlerSingleTplName] + if handlerSingleTpl == nil { + return fmt.Errorf("tpl %s not found", handlerSingleTplName) + } + data := SingleHandler{ + HttpMethod: method, + FilePath: handler.(Handler).FilePath, + PackageName: handler.(Handler).PackageName, + ProjPackage: handler.(Handler).ProjPackage, + } + handlerFunc := bytes.NewBuffer(nil) + err = handlerSingleTpl.Execute(handlerFunc, data) + if err != nil { + return fmt.Errorf("execute template \"%s\" failed, %v", handlerSingleTplName, err) + } + + buf := bytes.NewBuffer(nil) + _, err = buf.Write(file) + if err != nil { + return fmt.Errorf("write handler \"%s\" failed, %v", method.Name, err) + } + _, err = buf.Write(handlerFunc.Bytes()) + if err != nil { + return fmt.Errorf("write handler \"%s\" failed, %v", method.Name, err) + } + file = buf.Bytes() + } + + pkgGen.files = append(pkgGen.files, File{filePath, string(file), false, ""}) + + return nil +} + +func (pkgGen *HttpPackageGenerator) updateClient(client interface{}, clientTpl, filePath string, noRepeat bool) error { + isExist, err := util.PathExist(filePath) + if err != nil { + return err + } + if !isExist { + return pkgGen.TemplateGenerator.Generate(client, clientTpl, filePath, noRepeat) + } + logs.Infof("Client file:%s has been generated, so don't update it", filePath) + + return nil +} + +func (m *HttpMethod) InitComment() { + text := strings.TrimLeft(strings.TrimSpace(m.Comment), "/") + if text == "" { + text = "// " + m.Name + " ." + } else if strings.HasPrefix(text, m.Name) { + text = "// " + text + } else { + text = "// " + m.Name + " " + text + } + text = strings.Replace(text, "\n", "\n// ", -1) + if !strings.Contains(text, "@router ") { + text += "\n// @router " + m.Path + } + m.Comment = text + " [" + m.HTTPMethod + "]" +} + +func MapSerializer(serializer string) string { + switch serializer { + case "json": + return "JSON" + case "thrift": + return "Thrift" + case "pb": + return "ProtoBuf" + default: + return "JSON" + } +} + +func (h *Handler) Format() { + for _, m := range h.Methods { + m.Serializer = MapSerializer(m.Serializer) + m.InitComment() + } +} diff --git a/generator/layout.go b/generator/layout.go new file mode 100644 index 0000000..ea189f5 --- /dev/null +++ b/generator/layout.go @@ -0,0 +1,232 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "path/filepath" + "reflect" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/util" + "gopkg.in/yaml.v2" +) + +// Layout contains the basic information of idl +type Layout struct { + OutDir string + GoModule string + ServiceName string + UseApacheThrift bool + HasIdl bool + NeedGoMod bool + ModelDir string + HandlerDir string + RouterDir string +} + +// LayoutGenerator contains the information generated by generating the layout template +type LayoutGenerator struct { + ConfigPath string + TemplateGenerator +} + +var ( + layoutConfig = defaultLayoutConfig + packageConfig = defaultPkgConfig +) + +func SetDefaultTemplateConfig() { + layoutConfig = defaultLayoutConfig + packageConfig = defaultPkgConfig +} + +func (lg *LayoutGenerator) Init() error { + config := layoutConfig + // unmarshal from user-defined config file if it exists + if lg.ConfigPath != "" { + cdata, err := ioutil.ReadFile(lg.ConfigPath) + if err != nil { + return fmt.Errorf("read layout config from %s failed, err: %v", lg.ConfigPath, err.Error()) + } + config = TemplateConfig{} + if err = yaml.Unmarshal(cdata, &config); err != nil { + return fmt.Errorf("unmarshal layout config failed, err: %v", err.Error()) + } + } + + if reflect.DeepEqual(config, TemplateConfig{}) { + return errors.New("empty config") + } + lg.Config = &config + + return lg.TemplateGenerator.Init() +} + +// checkInited initialize template definition +func (lg *LayoutGenerator) checkInited() error { + if lg.tpls == nil || lg.dirs == nil { + if err := lg.Init(); err != nil { + return fmt.Errorf("init layout config failed, err: %v", err.Error()) + } + } + return nil +} + +func (lg *LayoutGenerator) Generate(data map[string]interface{}) error { + if err := lg.checkInited(); err != nil { + return err + } + return lg.TemplateGenerator.Generate(data, "", "", false) +} + +func (lg *LayoutGenerator) GenerateByService(service Layout) error { + if err := lg.checkInited(); err != nil { + return err + } + + if len(service.HandlerDir) != 0 { + // override the default "biz/handler/ping.go" to "HANDLER_DIR/ping.go" + defaultPingDir := defaultHandlerDir + sp + "ping.go" + if tpl, exist := lg.tpls[defaultPingDir]; exist { + delete(lg.tpls, defaultPingDir) + newPingDir := filepath.Clean(service.HandlerDir + sp + "ping.go") + lg.tpls[newPingDir] = tpl + } + } + + if len(service.RouterDir) != 0 { + defaultRegisterDir := defaultRouterDir + sp + registerTplName + if tpl, exist := lg.tpls[defaultRegisterDir]; exist { + delete(lg.tpls, defaultRegisterDir) + newRegisterDir := filepath.Clean(service.RouterDir + sp + registerTplName) + lg.tpls[newRegisterDir] = tpl + } + } + + if !service.NeedGoMod { + gomodFile := "go.mod" + if _, exist := lg.tpls[gomodFile]; exist { + delete(lg.tpls, gomodFile) + } + } + + if util.IsWindows() { + buildSh := "build.sh" + bootstrapSh := defaultScriptDir + sp + "bootstrap.sh" + if _, exist := lg.tpls[buildSh]; exist { + delete(lg.tpls, buildSh) + } + if _, exist := lg.tpls[bootstrapSh]; exist { + delete(lg.tpls, bootstrapSh) + } + } + + sd, err := serviceToLayoutData(service) + if err != nil { + return err + } + + rd, err := serviceToRouterData(service) + if err != nil { + return err + } + if service.HasIdl { + for k := range lg.tpls { + if strings.Contains(k, registerTplName) { + delete(lg.tpls, k) + break + } + } + } + + data := map[string]interface{}{ + "*": sd, + layoutConfig.Layouts[routerIndex].Path: rd, // router.go + layoutConfig.Layouts[routerGenIndex].Path: rd, // router_gen.go + } + + return lg.Generate(data) +} + +// serviceToLayoutData stores go mod, serviceName, UseApacheThrift mapping +func serviceToLayoutData(service Layout) (map[string]interface{}, error) { + goMod := service.GoModule + if goMod == "" { + return nil, errors.New("please specify go-module") + } + handlerPkg := filepath.Base(defaultHandlerDir) + if len(service.HandlerDir) != 0 { + handlerPkg = filepath.Base(service.HandlerDir) + } + routerPkg := filepath.Base(defaultRouterDir) + if len(service.RouterDir) != 0 { + routerPkg = filepath.Base(service.RouterDir) + } + serviceName := service.ServiceName + if len(serviceName) == 0 { + serviceName = meta.DefaultServiceName + } + + return map[string]interface{}{ + "GoModule": goMod, + "ServiceName": serviceName, + "UseApacheThrift": service.UseApacheThrift, + "HandlerPkg": handlerPkg, + "RouterPkg": routerPkg, + }, nil +} + +// serviceToRouterData stores the registers function, router import path, handler import path +func serviceToRouterData(service Layout) (map[string]interface{}, error) { + routerDir := sp + defaultRouterDir + handlerDir := sp + defaultHandlerDir + if len(service.RouterDir) != 0 { + routerDir = sp + service.RouterDir + } + if len(service.HandlerDir) != 0 { + handlerDir = sp + service.HandlerDir + } + return map[string]interface{}{ + "Registers": []string{}, + "RouterPkgPath": service.GoModule + util.PathToImport(routerDir, ""), + "HandlerPkgPath": service.GoModule + util.PathToImport(handlerDir, ""), + }, nil +} + +func (lg *LayoutGenerator) GenerateByConfig(configPath string) error { + if err := lg.checkInited(); err != nil { + return err + } + buf, err := ioutil.ReadFile(configPath) + if err != nil { + return fmt.Errorf("read data file '%s' failed, err: %v", configPath, err.Error()) + } + var data map[string]interface{} + if err := json.Unmarshal(buf, &data); err != nil { + return fmt.Errorf("unmarshal json data failed, err: %v", err.Error()) + } + return lg.Generate(data) +} + +func (lg *LayoutGenerator) Degenerate() error { + return lg.TemplateGenerator.Degenerate() +} diff --git a/generator/layout_tpl.go b/generator/layout_tpl.go new file mode 100644 index 0000000..68a6d63 --- /dev/null +++ b/generator/layout_tpl.go @@ -0,0 +1,220 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import "path/filepath" + +//-----------------------------------Default Layout----------------------------------------- + +const ( + sp = string(filepath.Separator) + + defaultBizDir = "biz" + defaultModelDir = "biz" + sp + "model" + defaultHandlerDir = "biz" + sp + "handler" + defaultServiceDir = "biz" + sp + "service" + defaultDalDir = "biz" + sp + "dal" + defaultScriptDir = "script" + defaultConfDir = "conf" + defaultRouterDir = "biz" + sp + "router" + defaultClientDir = "biz" + sp + "client" +) + +const ( + routerGenIndex = 8 + routerIndex = 9 + + RegisterFile = "router_gen.go" +) + +var defaultLayoutConfig = TemplateConfig{ + Layouts: []Template{ + { + Path: defaultDalDir + sp, + }, + { + Path: defaultHandlerDir + sp, + }, + { + Path: defaultModelDir + sp, + }, + { + Path: defaultServiceDir + sp, + }, + { + Path: "main.go", + Body: `// Code generated by hertz generator. + +package main + +import ( + "github.com/cloudwego/hertz/pkg/app/server" +) + +func main() { + h := server.Default() + + register(h) + h.Spin() +} + `, + }, + { + Path: "go.mod", + Delims: [2]string{"{{", "}}"}, + Body: `module {{.GoModule}} +{{- if .UseApacheThrift}} +replace github.com/apache/thrift => github.com/apache/thrift v0.13.0 +{{- end}} + `, + }, + { + Path: ".gitignore", + Body: `*.o +*.a +*.so +_obj +_test +*.[568vq] +[568vq].out +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* +_testmain.go +*.exe +*.exe~ +*.test +*.prof +*.rar +*.zip +*.gz +*.psd +*.bmd +*.cfg +*.pptx +*.log +*nohup.out +*settings.pyc +*.sublime-project +*.sublime-workspace +!.gitkeep +.DS_Store +/.idea +/.vscode +/output +*.local.yml +dumped_hertz_remote_config.json + `, + }, + { + Path: defaultHandlerDir + sp + "ping.go", + Body: `// Code generated by hertz generator. + +package {{.HandlerPkg}} + +import ( + "context" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol/consts" +) + +// Ping . +func Ping(ctx context.Context, c *app.RequestContext) { + c.JSON(consts.StatusOK, utils.H{ + "message": "pong", + }) +} +`, + }, + { + Path: RegisterFile, + Body: `// Code generated by hertz generator. DO NOT EDIT. + +package main + +import ( + "github.com/cloudwego/hertz/pkg/app/server" + router "{{.RouterPkgPath}}" +) + +// register registers all routers. +func register(r *server.Hertz) { + + router.GeneratedRegister(r) + + customizedRegister(r) +} +`, + }, + { + Path: "router.go", + Body: `// Code generated by hertz generator. + +package main + +import ( + "github.com/cloudwego/hertz/pkg/app/server" + handler "{{.HandlerPkgPath}}" +) + +// customizeRegister registers customize routers. +func customizedRegister(r *server.Hertz){ + r.GET("/ping", handler.Ping) + + // your code ... +} +`, + }, + { + Path: defaultRouterDir + sp + registerTplName, + Body: `// Code generated by hertz generator. DO NOT EDIT. + +package {{.RouterPkg}} + +import ( + "github.com/cloudwego/hertz/pkg/app/server" +) + +// GeneratedRegister registers routers generated by IDL. +func GeneratedRegister(r *server.Hertz){ + ` + insertPointNew + ` +} +`, + }, + { + Path: "build.sh", + Body: `#!/bin/bash +RUN_NAME={{.ServiceName}} +mkdir -p output/bin +cp script/* output 2>/dev/null +chmod +x output/bootstrap.sh +go build -o output/bin/${RUN_NAME}`, + }, + { + Path: defaultScriptDir + sp + "bootstrap.sh", + Body: `#!/bin/bash +CURDIR=$(cd $(dirname $0); pwd) +BinaryName={{.ServiceName}} +echo "$CURDIR/bin/${BinaryName}" +exec $CURDIR/bin/${BinaryName}`, + }, + }, +} diff --git a/generator/model.go b/generator/model.go new file mode 100644 index 0000000..733c8a4 --- /dev/null +++ b/generator/model.go @@ -0,0 +1,161 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import ( + "fmt" + "path/filepath" + "strings" + "text/template" + + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/generator/model/golang" + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/util" +) + +//---------------------------------Backend---------------------------------- + +type Option string + +const ( + OptionMarshalEnumToText Option = "MarshalEnumToText" + OptionTypedefAsTypeAlias Option = "TypedefAsTypeAlias" +) + +type Backend interface { + Template() (*template.Template, error) + List() map[string]string + SetOption(opts string) error + GetOptions() []string + Funcs(name string, fn interface{}) error +} + +type GolangBackend struct{} + +func (gb *GolangBackend) Template() (*template.Template, error) { + return golang.Template() +} + +func (gb *GolangBackend) List() map[string]string { + return golang.List() +} + +func (gb *GolangBackend) SetOption(opts string) error { + return golang.SetOption(opts) +} + +func (gb *GolangBackend) GetOptions() []string { + return golang.GetOptions() +} + +func (gb *GolangBackend) Funcs(name string, fn interface{}) error { + return golang.Funcs(name, fn) +} + +func switchBackend(backend meta.Backend) Backend { + switch backend { + case meta.BackendGolang: + return &GolangBackend{} + } + return loadThirdPartyBackend(string(backend)) +} + +func loadThirdPartyBackend(plugin string) Backend { + panic("no implement yet!") +} + +/**********************Generating*************************/ + +func (pkgGen *HttpPackageGenerator) LoadBackend(backend meta.Backend) error { + bd := switchBackend(backend) + if bd == nil { + return fmt.Errorf("no found backend '%s'", backend) + } + for _, opt := range pkgGen.Options { + if err := bd.SetOption(string(opt)); err != nil { + return fmt.Errorf("set option %s error, err: %v", opt, err.Error()) + } + } + + err := bd.Funcs("ROOT", func() *model.Model { + return pkgGen.curModel + }) + if err != nil { + return fmt.Errorf("register global function in model template failed, err: %v", err.Error()) + } + + tpl, err := bd.Template() + if err != nil { + return fmt.Errorf("load backend %s failed, err: %v", backend, err.Error()) + } + + if pkgGen.tpls == nil { + pkgGen.tpls = map[string]*template.Template{} + } + pkgGen.tpls[modelTplName] = tpl + pkgGen.loadedBackend = bd + return nil +} + +func (pkgGen *HttpPackageGenerator) GenModel(data *model.Model, gen bool) error { + if pkgGen.processedModels == nil { + pkgGen.processedModels = map[*model.Model]bool{} + } + + if _, ok := pkgGen.processedModels[data]; !ok { + var path string + var updatePackage bool + if strings.HasPrefix(data.Package, pkgGen.ProjPackage) && data.PackageName != pkgGen.ProjPackage { + path = data.Package[len(pkgGen.ProjPackage):] + } else { + path = data.Package + updatePackage = true + } + modelDir := util.SubDir(pkgGen.ModelDir, path) + if updatePackage { + data.Package = util.SubPackage(pkgGen.ProjPackage, modelDir) + } + data.FilePath = filepath.Join(modelDir, util.BaseNameAndTrim(data.FilePath)+".go") + + pkgGen.processedModels[data] = true + } + + for _, dep := range data.Imports { + if err := pkgGen.GenModel(dep, false); err != nil { + return fmt.Errorf("generate model %s failed, err: %v", dep.FilePath, err.Error()) + } + } + + if gen && !data.IsEmpty() { + pkgGen.curModel = data + removeDuplicateImport(data) + err := pkgGen.TemplateGenerator.Generate(data, modelTplName, data.FilePath, false) + pkgGen.curModel = nil + return err + } + return nil +} + +// Idls with the same Package do not need to refer to each other +func removeDuplicateImport(data *model.Model) { + for k, v := range data.Imports { + if data.Package == v.Package { + delete(data.Imports, k) + } + } +} diff --git a/generator/model/define.go b/generator/model/define.go new file mode 100644 index 0000000..ddde0eb --- /dev/null +++ b/generator/model/define.go @@ -0,0 +1,194 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +var ( + BaseTypes = []*Type{TypeBool, TypeByte, TypeInt8, TypeInt16, TypeInt32, TypeInt64, TypeUint8, TypeUint16, TypeUint32, TypeUint64, TypeFloat64, TypeString, TypeBinary} + ContainerTypes = []*Type{TypeBaseList, TypeBaseMap, TypeBaseSet} + BaseModel = Model{} +) + +var ( + TypeBool = &Type{ + Name: "bool", + Scope: &BaseModel, + Kind: KindBool, + } + TypeByte = &Type{ + Name: "int8", + Scope: &BaseModel, + Kind: KindInt8, + } + TypePbByte = &Type{ + Name: "byte", + Scope: &BaseModel, + Kind: KindInt8, + } + TypeUint8 = &Type{ + Name: "uint8", + Scope: &BaseModel, + Kind: KindInt8, + } + TypeUint16 = &Type{ + Name: "uint16", + Scope: &BaseModel, + Kind: KindInt16, + } + TypeUint32 = &Type{ + Name: "uint32", + Scope: &BaseModel, + Kind: KindInt32, + } + TypeUint64 = &Type{ + Name: "uint64", + Scope: &BaseModel, + Kind: KindInt64, + } + TypeUint = &Type{ + Name: "uint", + Scope: &BaseModel, + Kind: KindInt, + } + TypeInt8 = &Type{ + Name: "int8", + Scope: &BaseModel, + Kind: KindInt8, + } + TypeInt16 = &Type{ + Name: "int16", + Scope: &BaseModel, + Kind: KindInt16, + } + TypeInt32 = &Type{ + Name: "int32", + Scope: &BaseModel, + Kind: KindInt32, + } + TypeInt64 = &Type{ + Name: "int64", + Scope: &BaseModel, + Kind: KindInt64, + } + TypeInt = &Type{ + Name: "int", + Scope: &BaseModel, + Kind: KindInt, + } + TypeFloat32 = &Type{ + Name: "float32", + Scope: &BaseModel, + Kind: KindFloat64, + } + TypeFloat64 = &Type{ + Name: "float64", + Scope: &BaseModel, + Kind: KindFloat64, + } + TypeString = &Type{ + Name: "string", + Scope: &BaseModel, + Kind: KindString, + } + TypeBinary = &Type{ + Name: "binary", + Scope: &BaseModel, + Kind: KindSlice, + Category: CategoryBinary, + Extra: []*Type{TypePbByte}, + } + + TypeBaseMap = &Type{ + Name: "map", + Scope: &BaseModel, + Kind: KindMap, + Category: CategoryMap, + } + TypeBaseSet = &Type{ + Name: "set", + Scope: &BaseModel, + Kind: KindSlice, + Category: CategorySet, + } + TypeBaseList = &Type{ + Name: "list", + Scope: &BaseModel, + Kind: KindSlice, + Category: CategoryList, + } +) + +func NewCategoryType(typ *Type, cg Category) *Type { + cyp := *typ + cyp.Category = cg + return &cyp +} + +func NewStructType(name string, cg Category) *Type { + return &Type{ + Name: name, + Scope: nil, + Kind: KindStruct, + Category: cg, + Indirect: false, + Extra: nil, + HasNew: true, + } +} + +func NewFuncType(name string, cg Category) *Type { + return &Type{ + Name: name, + Scope: nil, + Kind: KindFunc, + Category: cg, + Indirect: false, + Extra: nil, + HasNew: false, + } +} + +func IsBaseType(typ *Type) bool { + for _, t := range BaseTypes { + if typ == t { + return true + } + } + return false +} + +func NewEnumType(name string, cg Category) *Type { + return &Type{ + Name: name, + Scope: &BaseModel, + Kind: KindInt, + Category: cg, + Indirect: false, + Extra: nil, + HasNew: true, + } +} + +func NewOneofType(name string) *Type { + return &Type{ + Name: name, + Scope: &BaseModel, + Kind: KindInterface, + Indirect: false, + Extra: nil, + HasNew: true, + } +} diff --git a/generator/model/expr.go b/generator/model/expr.go new file mode 100644 index 0000000..ab9d31a --- /dev/null +++ b/generator/model/expr.go @@ -0,0 +1,95 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "fmt" + "strconv" +) + +type BoolExpression struct { + Src bool +} + +func (boolExpr BoolExpression) Expression() string { + if boolExpr.Src { + return "true" + } else { + return "false" + } +} + +type StringExpression struct { + Src string +} + +func (stringExpr StringExpression) Expression() string { + return fmt.Sprintf("%q", stringExpr.Src) +} + +type NumberExpression struct { + Src string +} + +func (numExpr NumberExpression) Expression() string { + return numExpr.Src +} + +type ListExpression struct { + ElementType *Type + Elements []Literal +} + +type IntExpression struct { + Src int +} + +func (intExpr IntExpression) Expression() string { + return strconv.Itoa(intExpr.Src) +} + +type DoubleExpression struct { + Src float64 +} + +func (doubleExpr DoubleExpression) Expression() string { + return strconv.FormatFloat(doubleExpr.Src, 'f', -1, 64) +} + +func (listExpr ListExpression) Expression() string { + ret := "[]" + listExpr.ElementType.Name + "{\n" + for _, e := range listExpr.Elements { + ret += e.Expression() + ",\n" + } + ret += "\n}" + return ret +} + +type MapExpression struct { + KeyType *Type + ValueType *Type + Elements map[string]Literal +} + +func (mapExpr MapExpression) Expression() string { + ret := "map[" + mapExpr.KeyType.Name + "]" + mapExpr.ValueType.Name + "{\n" + for k, e := range mapExpr.Elements { + ret += fmt.Sprintf("%q: %s,\n", k, e.Expression()) + } + ret += "\n}" + return ret +} diff --git a/generator/model/golang/constant.go b/generator/model/golang/constant.go new file mode 100644 index 0000000..a17bec3 --- /dev/null +++ b/generator/model/golang/constant.go @@ -0,0 +1,23 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +var constants = ` +{{define "Constants"}} +const {{.Name}} {{.Type.ResolveName ROOT}} = {{.Value.Expression}} +{{end}} +` diff --git a/generator/model/golang/enum.go b/generator/model/golang/enum.go new file mode 100644 index 0000000..1bc0bb6 --- /dev/null +++ b/generator/model/golang/enum.go @@ -0,0 +1,67 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +// Enum . +var enum = ` +{{define "Enum"}} +{{- $EnumType := (Identify .Name)}} +type {{$EnumType}} {{.GoType}} + +const ( + {{- range $i, $e := .Values}} + {{$EnumType}}_{{$e.Name}} {{$EnumType}} = {{$e.Value.Expression}} + {{- end}} +) + +func (p {{$EnumType}}) String() string { + switch p { + {{- range $i, $e := .Values}} + case {{$EnumType}}_{{$e.Name}}: + return "{{printf "%s%s" $EnumType $e.Name | SnakeCase}}" + {{- end}} + } + return "" +} + +func {{$EnumType}}FromString(s string) ({{$EnumType}}, error) { + switch s { + {{- range $i, $e := .Values}} + case "{{printf "%s%s" $EnumType $e.Name | SnakeCase}}": + return {{$EnumType}}_{{$e.Name}}, nil + {{- end}} + } + return {{$EnumType}}(0), fmt.Errorf("not a valid {{$EnumType}} string") +} + +{{- if Features.MarshalEnumToText}} + +func (p {{$EnumType}}) MarshalText() ([]byte, error) { + return []byte(p.String()), nil +} + +func (p *{{$EnumType}}) UnmarshalText(text []byte) error { + q, err := {{$EnumType}}FromString(string(text)) + if err != nil { + return err + } + *p = q + return nil +} +{{- end}} +{{end}} +` diff --git a/generator/model/golang/file.go b/generator/model/golang/file.go new file mode 100644 index 0000000..cccaf54 --- /dev/null +++ b/generator/model/golang/file.go @@ -0,0 +1,62 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +var file = `{{$ROOT := . -}} +// Code generated by hz. + +package {{.PackageName}} + +import ( + "fmt" +{{- range $alias, $model := .Imports}} + {{$model.PackageName}} "{{$model.Package}}" +{{- end}} +) + +{{- range .Typedefs}} +{{template "Typedef" .}} +{{- end}} + +{{- range .Constants}} +{{template "Constants" .}} +{{- end}} + +{{- range .Variables}} +{{template "Variables" .}} +{{- end}} + +{{- range .Functions}} +{{template "Function" .}} +{{- end}} + +{{- range .Enums}} +{{template "Enum" .}} +{{- end}} + +{{- range .Oneofs}} +{{template "Oneof" .}} +{{- end}} + +{{- range .Structs}} +{{template "Struct" .}} +{{- end}} + +{{- range .Methods}} +{{template "Method" .}} +{{- end}} +` diff --git a/generator/model/golang/function.go b/generator/model/golang/function.go new file mode 100644 index 0000000..25f1a1d --- /dev/null +++ b/generator/model/golang/function.go @@ -0,0 +1,46 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +var function = ` +{{define "Function"}} +func {{template "FuncBody" . -}} +{{end}}{{/* define "Function" */}} + +{{define "FuncBody"}} +{{- .Name -}}( +{{- range $i, $arg := .Args -}} +{{- if gt $i 0}}, {{end -}} +{{$arg.Name}} {{$arg.Type.ResolveName ROOT}} +{{- end -}}{{/* range */}}) +{{- if gt (len .Rets) 0}} ({{end -}} +{{- range $i, $ret := .Rets -}} +{{- if gt $i 0}}, {{end -}} +{{$ret.Type.ResolveName ROOT}} +{{- end -}}{{/* range */}} +{{- if gt (len .Rets) 0}}) {{end -}}{ +{{.Code}} +} +{{end}}{{/* define "FuncBody" */}} +` + +var method = ` +{{define "Method"}} +func ({{.ReceiverName}} {{.ReceiverType.ResolveName ROOT}}) +{{- template "FuncBody" .Function -}} +{{end}} +` diff --git a/generator/model/golang/init.go b/generator/model/golang/init.go new file mode 100644 index 0000000..55b834e --- /dev/null +++ b/generator/model/golang/init.go @@ -0,0 +1,132 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +import ( + "fmt" + "strings" + "text/template" +) + +var tpls *template.Template + +var list = map[string]string{ + "file": file, + "typedef": typedef, + "constants": constants, + "variables": variables, + "function": function, + "enum": enum, + "struct": structLike, + "method": method, + "oneof": oneof, +} + +/***********************Export API*******************************/ + +func Template() (*template.Template, error) { + if tpls != nil { + return tpls, nil + } + tpls = new(template.Template) + + tpls = tpls.Funcs(funcMap) + + var err error + for k, li := range list { + tpls, err = tpls.Parse(li) + if err != nil { + return nil, fmt.Errorf("parse template '%s' failed, err: %v", k, err.Error()) + } + } + return tpls, nil +} + +func List() map[string]string { + return list +} + +/***********************Template Funcs**************************/ + +var funcMap = template.FuncMap{ + "Features": getFeatures, + "Identify": identify, + "CamelCase": camelCase, + "SnakeCase": snakeCase, + "GetTypedefReturnStr": getTypedefReturnStr, +} + +func Funcs(name string, fn interface{}) error { + if _, ok := funcMap[name]; ok { + return fmt.Errorf("duplicate function: %s has been registered", name) + } + funcMap[name] = fn + return nil +} + +func identify(name string) string { + return name +} + +func camelCase(name string) string { + return name +} + +func snakeCase(name string) string { + return name +} + +func getTypedefReturnStr(name string) string { + if strings.Contains(name, ".") { + idx := strings.LastIndex(name, ".") + return name[:idx] + "." + "New" + name[idx+1:] + "()" + + } + return "New" + name + "()" +} + +/***********************Template Options**************************/ + +type feature struct { + MarshalEnumToText bool + TypedefAsTypeAlias bool +} + +var features = feature{} + +func getFeatures() feature { + return features +} + +func SetOption(opt string) error { + switch opt { + case "MarshalEnumToText": + features.MarshalEnumToText = true + case "TypedefAsTypeAlias": + features.TypedefAsTypeAlias = true + } + return nil +} + +var Options = []string{ + "MarshalEnumToText", + "TypedefAsTypeAlias", +} + +func GetOptions() []string { + return Options +} diff --git a/generator/model/golang/oneof.go b/generator/model/golang/oneof.go new file mode 100644 index 0000000..f2bc96c --- /dev/null +++ b/generator/model/golang/oneof.go @@ -0,0 +1,45 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +var oneof = ` +{{define "Oneof"}} +type {{$.InterfaceName}} interface { + {{$.InterfaceName}}() +} + +{{range $i, $f := .Choices}} +type {{$f.MessageName}}_{{$f.ChoiceName}} struct { + {{$f.ChoiceName}} {{$f.Type.ResolveName ROOT}} +} +{{end}} + +{{range $i, $f := .Choices}} +func (*{{$f.MessageName}}_{{$f.ChoiceName}}) {{$.InterfaceName}}() {} +{{end}} + +{{range $i, $f := .Choices}} +func (p *{{$f.MessageName}}) Get{{$f.ChoiceName}}() {{$f.Type.ResolveName ROOT}} { + if p, ok := p.Get{{$.OneofName}}().(*{{$f.MessageName}}_{{$f.ChoiceName}}); ok { + return p.{{$f.ChoiceName}} + } + return {{$f.Type.ResolveDefaultValue}} +} +{{end}} + +{{end}} +` diff --git a/generator/model/golang/struct.go b/generator/model/golang/struct.go new file mode 100644 index 0000000..58cafcc --- /dev/null +++ b/generator/model/golang/struct.go @@ -0,0 +1,120 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +// StructLike is the code template for struct, union, and exception. +var structLike = ` +{{define "Struct"}} +{{- $TypeName := (Identify .Name) -}} +{{$MessageLeadingComments := .LeadingComments}} +{{if ne (len $MessageLeadingComments) 0}} +//{{$MessageLeadingComments}} +{{end -}} +type {{$TypeName}} struct { +{{- range $i, $f := .Fields}} +{{- $FieldLeadingComments := $f.LeadingComments}} +{{$FieldTrailingComments := $f.TrailingComments -}} +{{- if ne (len $FieldLeadingComments) 0 -}} + //{{$FieldLeadingComments}} +{{end -}} +{{- if $f.IsPointer -}} + {{$f.Name}} *{{$f.Type.ResolveName ROOT}} {{$f.GenGoTags}}{{if ne (len $FieldTrailingComments) 0}} //{{$FieldTrailingComments}}{{end -}} +{{- else -}} + {{$f.Name}} {{$f.Type.ResolveName ROOT}} {{$f.GenGoTags}}{{if ne (len $FieldTrailingComments) 0}} //{{$FieldTrailingComments}}{{end -}} +{{- end -}} +{{- end}} +} + +func New{{$TypeName}}() *{{$TypeName}} { + return &{{$TypeName}}{ + {{template "StructLikeDefault" .}} + } +} + +{{template "FieldGetOrSet" .}} + +{{if eq .Category 14}} +func (p *{{$TypeName}}) CountSetFields{{$TypeName}}() int { + count := 0 + {{- range $i, $f := .Fields}} + {{- if $f.Type.IsSettable}} + if p.IsSet{{$f.Name}}() { + count++ + } + {{- end}} + {{- end}} + return count +} +{{- end}} + +func (p *{{$TypeName}}) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("{{$TypeName}}(%+v)", *p) +} + +{{- if eq .Category 15}} +func (p *{{$TypeName}}) Error() string { + return p.String() +} +{{- end}} +{{- end}}{{/* define "StructLike" */}} + +{{- define "StructLikeDefault"}} +{{- range $i, $f := .Fields}} + {{- if $f.IsSetDefault}} + {{$f.Name}}: {{$f.DefaultValue.Expression}}, + {{- end}} +{{- end}} +{{- end -}}{{/* define "StructLikeDefault" */}} + +{{- define "FieldGetOrSet"}} +{{- $TypeName := (Identify .Name)}} +{{- range $i, $f := .Fields}} +{{$FieldName := $f.Name}} +{{$FieldTypeName := $f.Type.ResolveName ROOT}} + +{{- if $f.Type.IsSettable}} +func (p *{{$TypeName}}) IsSet{{$FieldName}}() bool { + return p.{{$FieldName}} != nil +} +{{- end}}{{/* IsSettable . */}} + +func (p *{{$TypeName}}) Get{{$FieldName}}() {{$FieldTypeName}} { + {{- if $f.Type.IsSettable}} + if !p.IsSet{{$FieldName}}() { + return {{with $f.DefaultValue}}{{$f.DefaultValue.Expression}}{{else}}nil{{end}} + } + {{- end}} +{{- if $f.IsPointer}} + return *p.{{$FieldName}} +{{else}} + return p.{{$FieldName}} +{{- end -}} +} + +func (p *{{$TypeName}}) Set{{$FieldName}}(val {{$FieldTypeName}}) { +{{- if $f.IsPointer}} + *p.{{$FieldName}} = val +{{else}} + p.{{$FieldName}} = val +{{- end -}} +} +{{- end}}{{/* range .Fields */}} +{{- end}}{{/* define "FieldGetOrSet" */}} +` diff --git a/generator/model/golang/typedef.go b/generator/model/golang/typedef.go new file mode 100644 index 0000000..79eaa86 --- /dev/null +++ b/generator/model/golang/typedef.go @@ -0,0 +1,32 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +// Typedef . +var typedef = ` +{{define "Typedef"}} +{{- $NewTypeName := (Identify .Alias)}} +{{- $OldTypeName := .Type.ResolveNameForTypedef ROOT}} +type {{$NewTypeName}} = {{$OldTypeName}} + +{{if eq .Type.Kind 25}}{{if .Type.HasNew}} +func New{{$NewTypeName}}() *{{$NewTypeName}} { + return {{(GetTypedefReturnStr $OldTypeName)}} +} +{{- end}}{{- end}} +{{- end}} +` diff --git a/generator/model/golang/variable.go b/generator/model/golang/variable.go new file mode 100644 index 0000000..ea7b3a8 --- /dev/null +++ b/generator/model/golang/variable.go @@ -0,0 +1,23 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +var variables = ` +{{- define "Variables"}} +var {{.Name}} {{.Type.ResolveName ROOT}} = {{.Value.Expression}} +{{end}} +` diff --git a/generator/model/model.go b/generator/model/model.go new file mode 100644 index 0000000..f402804 --- /dev/null +++ b/generator/model/model.go @@ -0,0 +1,417 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "errors" + "fmt" + "strings" +) + +type Kind uint + +const ( + KindInvalid Kind = iota + KindBool + KindInt + KindInt8 + KindInt16 + KindInt32 + KindInt64 + KindUint + KindUint8 + KindUint16 + KindUint32 + KindUint64 + KindUintptr + KindFloat32 + KindFloat64 + KindComplex64 + KindComplex128 + KindArray + KindChan + KindFunc + KindInterface + KindMap + KindPtr + KindSlice + KindString + KindStruct + KindUnsafePointer +) + +type Category int64 + +const ( + CategoryConstant Category = 1 + CategoryBinary Category = 8 + CategoryMap Category = 9 + CategoryList Category = 10 + CategorySet Category = 11 + CategoryEnum Category = 12 + CategoryStruct Category = 13 + CategoryUnion Category = 14 + CategoryException Category = 15 + CategoryTypedef Category = 16 + CategoryService Category = 17 +) + +type Model struct { + FilePath string + Package string + Imports map[string]*Model //{{import}}:Model + + // rendering data + PackageName string + // Imports map[string]string //{{alias}}:{{import}} + Typedefs []TypeDef + Constants []Constant + Variables []Variable + Functions []Function + Enums []Enum + Structs []Struct + Methods []Method + Oneofs []Oneof +} + +func (m Model) IsEmpty() bool { + return len(m.Typedefs) == 0 && len(m.Constants) == 0 && len(m.Variables) == 0 && + len(m.Functions) == 0 && len(m.Enums) == 0 && len(m.Structs) == 0 && len(m.Methods) == 0 +} + +type Models []*Model + +func (a *Models) MergeMap(b map[string]*Model) { + for _, v := range b { + insert := true + for _, p := range *a { + if p == v { + insert = false + } + } + if insert { + *a = append(*a, v) + } + } + return +} + +func (a *Models) MergeArray(b []*Model) { + for _, v := range b { + insert := true + for _, p := range *a { + if p == v { + insert = false + } + } + if insert { + *a = append(*a, v) + } + } + return +} + +type RequiredNess int + +const ( + RequiredNess_Default RequiredNess = 0 + RequiredNess_Required RequiredNess = 1 + RequiredNess_Optional RequiredNess = 2 +) + +type Type struct { + Name string + Scope *Model + Kind Kind + Indirect bool + Category Category + Extra []*Type // [{key_type},{value_type}] for map, [{element_type}] for list or set + HasNew bool +} + +func (rt *Type) ResolveDefaultValue() string { + if rt == nil { + return "" + } + switch rt.Kind { + case KindInt, KindInt8, KindInt16, KindInt32, KindInt64, KindUint, KindUint16, KindUint32, KindUint64, + KindFloat32, KindFloat64, KindComplex64, KindComplex128: + return "0" + case KindBool: + return "false" + case KindString: + return "\"\"" + default: + return "nil" + } +} + +func (rt *Type) ResolveNameForTypedef(scope *Model) (string, error) { + if rt == nil { + return "", errors.New("type is nil") + } + name := rt.Name + if rt.Scope == nil { + return rt.Name, nil + } + + switch rt.Kind { + case KindArray, KindSlice: + if len(rt.Extra) != 1 { + return "", fmt.Errorf("the type: %s should have 1 extra type, but has %d", rt.Name, len(rt.Extra)) + } + resolveName, err := rt.Extra[0].ResolveName(scope) + if err != nil { + return "", err + } + name = fmt.Sprintf("[]%s", resolveName) + case KindMap: + if len(rt.Extra) != 2 { + return "", fmt.Errorf("the type: %s should have 2 extra types, but has %d", rt.Name, len(rt.Extra)) + } + resolveKey, err := rt.Extra[0].ResolveName(scope) + if err != nil { + return "", err + } + resolveValue, err := rt.Extra[1].ResolveName(scope) + if err != nil { + return "", err + } + name = fmt.Sprintf("map[%s]%s", resolveKey, resolveValue) + case KindChan: + if len(rt.Extra) != 1 { + return "", fmt.Errorf("the type: %s should have 1 extra type, but has %d", rt.Name, len(rt.Extra)) + } + resolveName, err := rt.Extra[0].ResolveName(scope) + if err != nil { + return "", err + } + name = fmt.Sprintf("chan %s", resolveName) + } + + if scope != nil && rt.Scope != &BaseModel && rt.Scope.Package != scope.Package { + name = rt.Scope.PackageName + "." + name + } + return name, nil +} + +func (rt *Type) ResolveName(scope *Model) (string, error) { + if rt == nil { + return "", fmt.Errorf("type is nil") + } + name := rt.Name + if rt.Scope == nil { + if rt.Kind == KindStruct { + return "*" + rt.Name, nil + } + return rt.Name, nil + } + + if rt.Category == CategoryTypedef { + if scope != nil && rt.Scope != &BaseModel && rt.Scope.Package != scope.Package { + name = rt.Scope.PackageName + "." + name + } + + if rt.Kind == KindStruct { + name = "*" + name + } + + return name, nil + } + + switch rt.Kind { + case KindArray, KindSlice: + if len(rt.Extra) != 1 { + return "", fmt.Errorf("The type: %s should have 1 extra type, but has %d", rt.Name, len(rt.Extra)) + } + resolveName, err := rt.Extra[0].ResolveName(scope) + if err != nil { + return "", err + } + name = fmt.Sprintf("[]%s", resolveName) + case KindMap: + if len(rt.Extra) != 2 { + return "", fmt.Errorf("The type: %s should have 2 extra type, but has %d", rt.Name, len(rt.Extra)) + } + resolveKey, err := rt.Extra[0].ResolveName(scope) + if err != nil { + return "", err + } + resolveValue, err := rt.Extra[1].ResolveName(scope) + if err != nil { + return "", err + } + name = fmt.Sprintf("map[%s]%s", resolveKey, resolveValue) + case KindChan: + if len(rt.Extra) != 1 { + return "", fmt.Errorf("The type: %s should have 1 extra type, but has %d", rt.Name, len(rt.Extra)) + } + resolveName, err := rt.Extra[0].ResolveName(scope) + if err != nil { + return "", err + } + name = fmt.Sprintf("chan %s", resolveName) + } + + if scope != nil && rt.Scope != &BaseModel && rt.Scope.Package != scope.Package { + name = rt.Scope.PackageName + "." + name + } + + if rt.Kind == KindStruct { + name = "*" + name + } + return name, nil +} + +func (rt *Type) IsBinary() bool { + return rt.Category == CategoryBinary && (rt.Kind == KindSlice || rt.Kind == KindArray) +} + +func (rt *Type) IsBaseType() bool { + return rt.Kind < KindComplex64 +} + +func (rt *Type) IsSettable() bool { + switch rt.Kind { + case KindArray, KindChan, KindFunc, KindInterface, KindMap, KindPtr, KindSlice, KindUnsafePointer: + return true + } + return false +} + +type TypeDef struct { + Scope *Model + Alias string + Type *Type +} + +type Constant struct { + Scope *Model + Name string + Type *Type + Value Literal +} + +type Literal interface { + Expression() string +} + +type Variable struct { + Scope *Model + Name string + Type *Type + Value Literal +} + +type Function struct { + Scope *Model + Name string + Args []Variable + Rets []Variable + Code string +} + +type Method struct { + Scope *Model + ReceiverName string + ReceiverType *Type + ByPtr bool + Function +} + +type Enum struct { + Scope *Model + Name string + GoType string + Values []Constant +} + +type Struct struct { + Scope *Model + Name string + Fields []Field + Category Category + LeadingComments string +} + +type Field struct { + Scope *Struct + Name string + Type *Type + IsSetDefault bool + DefaultValue Literal + Required RequiredNess + Tags Tags + LeadingComments string + TrailingComments string + IsPointer bool +} + +type Oneof struct { + MessageName string + OneofName string + InterfaceName string + Choices []Choice +} + +type Choice struct { + MessageName string + ChoiceName string + Type *Type +} + +type Tags []Tag + +type Tag struct { + Key string + Value string + IsDefault bool // default tag +} + +func (ts Tags) String() string { + ret := make([]string, 0, len(ts)) + for _, t := range ts { + ret = append(ret, fmt.Sprintf("%v:%q", t.Key, t.Value)) + } + return strings.Join(ret, " ") +} + +func (ts *Tags) Remove(name string) { + ret := make([]Tag, 0, len(*ts)) + for _, t := range *ts { + if t.Key != name { + ret = append(ret, t) + } + } + *ts = ret +} + +func (ts Tags) Len() int { return len(ts) } + +func (ts Tags) Less(i, j int) bool { + return ts[i].Key < ts[j].Key +} + +func (ts Tags) Swap(i, j int) { ts[i], ts[j] = ts[j], ts[i] } + +func (f Field) GenGoTags() string { + if len(f.Tags) == 0 { + return "" + } + + return fmt.Sprintf("`%s`", f.Tags.String()) +} diff --git a/generator/model_test.go b/generator/model_test.go new file mode 100644 index 0000000..9e11f52 --- /dev/null +++ b/generator/model_test.go @@ -0,0 +1,240 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import ( + "testing" + "text/template" + + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/meta" +) + +type StringValue struct { + src string +} + +func (sv *StringValue) Expression() string { + return sv.src +} + +func TestIdlGenerator_GenModel(t *testing.T) { + typeModel := &model.Type{ + Name: "Model", + Kind: model.KindStruct, + Indirect: true, + } + typeErr := &model.Type{ + Name: "error", + Kind: model.KindInterface, + Indirect: false, + } + + type fields struct { + ConfigPath string + OutputDir string + Backend meta.Backend + handlerDir string + routerDir string + modelDir string + ProjPackage string + Config *TemplateConfig + tpls map[string]*template.Template + } + type args struct { + data *model.Model + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "", + fields: fields{ + OutputDir: "./testdata", + Backend: meta.BackendGolang, + }, + args: args{ + data: &model.Model{ + FilePath: "idl/main.thrift", + Package: "model/psm", + PackageName: "psm", + Imports: map[string]*model.Model{ + "base": { + Package: "model/base", + PackageName: "base", + }, + }, + Typedefs: []model.TypeDef{ + { + Alias: "HerztModel", + Type: typeModel, + }, + }, + Constants: []model.Constant{ + { + Name: "OBJ", + Type: typeErr, + Value: &StringValue{"fmt.Errorf(\"EOF\")"}, + }, + }, + Variables: []model.Variable{ + { + Name: "Object", + Type: typeModel, + Value: &StringValue{"&Model{}"}, + }, + }, + Functions: []model.Function{ + { + Name: "Init", + Args: nil, + Rets: []model.Variable{ + { + Name: "err", + Type: typeErr, + }, + }, + Code: "return nil", + }, + }, + Enums: []model.Enum{ + { + Name: "Sex", + Values: []model.Constant{ + { + Name: "Male", + Type: &model.Type{ + Name: "int", + Kind: model.KindInt, + Indirect: false, + Category: 1, + }, + Value: &StringValue{"1"}, + }, + { + Name: "Femal", + Type: &model.Type{ + Name: "int", + Kind: model.KindInt, + Indirect: false, + Category: 1, + }, + Value: &StringValue{"2"}, + }, + }, + }, + }, + Structs: []model.Struct{ + { + Name: "Model", + Fields: []model.Field{ + { + Name: "A", + Type: &model.Type{ + Name: "[]byte", + Kind: model.KindSlice, + Indirect: false, + Category: model.CategoryBinary, + }, + IsSetDefault: true, + DefaultValue: &StringValue{"[]byte(\"\")"}, + }, + { + Name: "B", + Type: &model.Type{ + Name: "Base", + Kind: model.KindStruct, + Indirect: false, + }, + }, + }, + Category: model.CategoryUnion, + }, + }, + Methods: []model.Method{ + { + ReceiverName: "self", + ReceiverType: typeModel, + ByPtr: true, + Function: model.Function{ + Name: "Bind", + Args: []model.Variable{ + { + Name: "c", + Type: &model.Type{ + Name: "RequestContext", + Scope: &model.Model{ + PackageName: "hertz", + }, + Kind: model.KindStruct, + Indirect: true, + }, + }, + }, + Rets: []model.Variable{ + { + Name: "error", + Type: typeErr, + }, + }, + Code: "return nil", + }, + }, + }, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + self := &HttpPackageGenerator{ + ConfigPath: tt.fields.ConfigPath, + Backend: tt.fields.Backend, + HandlerDir: tt.fields.handlerDir, + RouterDir: tt.fields.routerDir, + ModelDir: tt.fields.modelDir, + ProjPackage: tt.fields.ProjPackage, + TemplateGenerator: TemplateGenerator{ + OutputDir: tt.fields.OutputDir, + Config: tt.fields.Config, + tpls: tt.fields.tpls, + }, + Options: []Option{ + OptionTypedefAsTypeAlias, + OptionMarshalEnumToText, + }, + } + + err := self.LoadBackend(meta.BackendGolang) + if err != nil { + t.Fatal(err) + } + + if err := self.GenModel(tt.args.data, true); (err != nil) != tt.wantErr { + t.Errorf("IdlGenerator.GenModel() error = %v, wantErr %v", err, tt.wantErr) + } + if err := self.Persist(); err != nil { + t.Fatal(err) + } + }) + } +} diff --git a/generator/package.go b/generator/package.go new file mode 100644 index 0000000..01a6884 --- /dev/null +++ b/generator/package.go @@ -0,0 +1,199 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import ( + "errors" + "fmt" + "io/ioutil" + "path/filepath" + "reflect" + "text/template" + + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/util" + "gopkg.in/yaml.v2" +) + +type HttpPackage struct { + IdlName string + Package string + Services []*Service + Models []*model.Model + RouterInfo *Router +} + +type Service struct { + Name string + Methods []*HttpMethod + ClientMethods []*ClientMethod + Models []*model.Model // all dependency models + BaseDomain string // base domain for client code + ServiceGroup string // service level router group + ServiceGenDir string // handler_dir for handler_by_service +} + +// HttpPackageGenerator is used to record the configuration related to generating hertz http code. +type HttpPackageGenerator struct { + ConfigPath string // package template path + Backend meta.Backend // model template + Options []Option + CmdType string + ProjPackage string // go module for project + HandlerDir string + RouterDir string + ModelDir string + UseDir string // model dir for third repo + ClientDir string // client dir for "new"/"update" command + IdlClientDir string // client dir for "client" command + ForceClientDir string // client dir without namespace for "client" command + BaseDomain string // request domain for "client" command + ServiceGenDir string + + NeedModel bool + HandlerByMethod bool // generate handler files with method dimension + SnakeStyleMiddleware bool // use snake name style for middleware + + loadedBackend Backend + curModel *model.Model + processedModels map[*model.Model]bool + + TemplateGenerator +} + +func (pkgGen *HttpPackageGenerator) Init() error { + defaultConfig := packageConfig + customConfig := TemplateConfig{} + // unmarshal from user-defined config file if it exists + if pkgGen.ConfigPath != "" { + cdata, err := ioutil.ReadFile(pkgGen.ConfigPath) + if err != nil { + return fmt.Errorf("read layout config from %s failed, err: %v", pkgGen.ConfigPath, err.Error()) + } + if err = yaml.Unmarshal(cdata, &customConfig); err != nil { + return fmt.Errorf("unmarshal layout config failed, err: %v", err.Error()) + } + if reflect.DeepEqual(customConfig, TemplateConfig{}) { + return errors.New("empty config") + } + } + + if pkgGen.tpls == nil { + pkgGen.tpls = make(map[string]*template.Template, len(defaultConfig.Layouts)) + } + if pkgGen.tplsInfo == nil { + pkgGen.tplsInfo = make(map[string]*Template, len(defaultConfig.Layouts)) + } + + // extract routerTplName/middlewareTplName/handlerTplName/registerTplName/modelTplName/clientTplName directories + // load default template + for _, layout := range defaultConfig.Layouts { + // default template use "fileName" as template name + path := filepath.Base(layout.Path) + err := pkgGen.loadLayout(layout, path, true) + if err != nil { + return err + } + } + + // override the default template, other customized file template will be loaded by "TemplateGenerator.Init" + for _, layout := range customConfig.Layouts { + if !IsDefaultPackageTpl(layout.Path) { + continue + } + err := pkgGen.loadLayout(layout, layout.Path, true) + if err != nil { + return err + } + } + + pkgGen.Config = &customConfig + // load Model tpl if need + if pkgGen.Backend != "" { + if err := pkgGen.LoadBackend(pkgGen.Backend); err != nil { + return fmt.Errorf("load model template failed, err: %v", err.Error()) + } + } + + pkgGen.processedModels = make(map[*model.Model]bool) + pkgGen.TemplateGenerator.isPackageTpl = true + + return pkgGen.TemplateGenerator.Init() +} + +func (pkgGen *HttpPackageGenerator) checkInited() (bool, error) { + if pkgGen.tpls == nil { + if err := pkgGen.Init(); err != nil { + return false, fmt.Errorf("init layout config failed, err: %v", err.Error()) + } + } + return pkgGen.ConfigPath == "", nil +} + +func (pkgGen *HttpPackageGenerator) Generate(pkg *HttpPackage) error { + if _, err := pkgGen.checkInited(); err != nil { + return err + } + if len(pkg.Models) != 0 { + for _, m := range pkg.Models { + if err := pkgGen.GenModel(m, pkgGen.NeedModel); err != nil { + return fmt.Errorf("generate model %s failed, err: %v", m.FilePath, err.Error()) + } + } + } + + if pkgGen.CmdType == meta.CmdClient { + // default client dir + clientDir := pkgGen.IdlClientDir + // user specify client dir + if len(pkgGen.ClientDir) != 0 { + clientDir = pkgGen.ClientDir + } + if err := pkgGen.genClient(pkg, clientDir); err != nil { + return err + } + if err := pkgGen.genCustomizedFile(pkg); err != nil { + return err + } + return nil + } + + // this is for handler_by_service, the handler_dir is {$HANDLER_DIR}/{$PKG} + handlerDir := util.SubDir(pkgGen.HandlerDir, pkg.Package) + if pkgGen.HandlerByMethod { + handlerDir = pkgGen.HandlerDir + } + handlerPackage := util.SubPackage(pkgGen.ProjPackage, handlerDir) + routerDir := util.SubDir(pkgGen.RouterDir, pkg.Package) + routerPackage := util.SubPackage(pkgGen.ProjPackage, routerDir) + + root := NewRouterTree() + if err := pkgGen.genHandler(pkg, handlerDir, handlerPackage, root); err != nil { + return err + } + + if err := pkgGen.genRouter(pkg, root, handlerPackage, routerDir, routerPackage); err != nil { + return err + } + + if err := pkgGen.genCustomizedFile(pkg); err != nil { + return err + } + + return nil +} diff --git a/generator/package_tpl.go b/generator/package_tpl.go new file mode 100644 index 0000000..839794a --- /dev/null +++ b/generator/package_tpl.go @@ -0,0 +1,978 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +var ( + routerTplName = "router.go" + middlewareTplName = "middleware.go" + middlewareSingleTplName = "middleware_single.go" + handlerTplName = "handler.go" + handlerSingleTplName = "handler_single.go" + modelTplName = "model.go" + registerTplName = "register.go" + clientTplName = "client.go" // generate a default client for server + hertzClientTplName = "hertz_client.go" // underlying client for client command + idlClientName = "idl_client.go" // client of service for quick call + + insertPointNew = "//INSERT_POINT: DO NOT DELETE THIS LINE!" + insertPointPatternNew = `//INSERT_POINT\: DO NOT DELETE THIS LINE\!` +) + +var templateNameSet = map[string]string{ + routerTplName: routerTplName, + middlewareTplName: middlewareTplName, + middlewareSingleTplName: middlewareSingleTplName, + handlerTplName: handlerTplName, + handlerSingleTplName: handlerSingleTplName, + modelTplName: modelTplName, + registerTplName: registerTplName, + clientTplName: clientTplName, + hertzClientTplName: hertzClientTplName, + idlClientName: idlClientName, +} + +func IsDefaultPackageTpl(name string) bool { + if _, exist := templateNameSet[name]; exist { + return true + } + + return false +} + +var defaultPkgConfig = TemplateConfig{ + Layouts: []Template{ + { + Path: defaultHandlerDir + sp + handlerTplName, + Delims: [2]string{"{{", "}}"}, + Body: `// Code generated by hertz generator. + +package {{.PackageName}} + +import ( + "context" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/protocol/consts" + +{{- range $k, $v := .Imports}} + {{$k}} "{{$v.Package}}" +{{- end}} +) + +{{range $_, $MethodInfo := .Methods}} +{{$MethodInfo.Comment}} +func {{$MethodInfo.Name}}(ctx context.Context, c *app.RequestContext) { + var err error + {{if ne $MethodInfo.RequestTypeName "" -}} + var req {{$MethodInfo.RequestTypeName}} + err = c.BindAndValidate(&req) + if err != nil { + c.String(consts.StatusBadRequest, err.Error()) + return + } + {{end}} + resp := new({{$MethodInfo.ReturnTypeName}}) + + c.{{.Serializer}}(consts.StatusOK, resp) +} +{{end}} + `, + }, + { + Path: defaultRouterDir + sp + routerTplName, + Delims: [2]string{"{{", "}}"}, + Body: `// Code generated by hertz generator. DO NOT EDIT. + +package {{$.PackageName}} + +import ( + "github.com/cloudwego/hertz/pkg/app/server" + + {{- range $k, $v := .HandlerPackages}} + {{$k}} "{{$v}}" + {{- end}} +) + +/* + This file will register all the routes of the services in the master idl. + And it will update automatically when you use the "update" command for the idl. + So don't modify the contents of the file, or your code will be deleted when it is updated. + */ + +{{define "g"}} +{{- if eq .Path "/"}}r +{{- else}}{{.GroupName}}{{end}} +{{- end}} + +{{define "G"}} +{{- if ne .Handler ""}} + {{- .GroupName}}.{{.HttpMethod}}("{{.Path}}", append({{.HandlerMiddleware}}Mw(), {{.Handler}})...) +{{- end}} +{{- if ne (len .Children) 0}} +{{.MiddleWare}} := {{template "g" .}}.Group("{{.Path}}", {{.GroupMiddleware}}Mw()...) +{{- end}} +{{- range $_, $router := .Children}} +{{- if ne .Handler ""}} + {{template "G" $router}} +{{- else}} + { {{template "G" $router}} + } +{{- end}} +{{- end}} +{{- end}} + +// Register register routes based on the IDL 'api.${HTTP Method}' annotation. +func Register{{$.IdlName}}(r *server.Hertz) { +{{template "G" .Router}} +} + + `, + }, + { + Path: defaultRouterDir + sp + registerTplName, + Body: `// Code generated by hertz generator. DO NOT EDIT. + +package {{.PackageName}} + +import ( + "github.com/cloudwego/hertz/pkg/app/server" + {{$.DepPkgAlias}} "{{$.DepPkg}}" +) + +// GeneratedRegister registers routers generated by IDL. +func GeneratedRegister(r *server.Hertz){ + ` + insertPointNew + ` + {{$.DepPkgAlias}}.{{$.RegisterName}}(r) +} +`, + }, + // Model tpl is imported by model generator. Here only decides model directory. + { + Path: defaultModelDir + sp + modelTplName, + Body: ``, + }, + { + Path: defaultRouterDir + sp + middlewareTplName, + Delims: [2]string{"{{", "}}"}, + Body: `// Code generated by hertz generator. + +package {{$.PackageName}} + +import ( + "github.com/cloudwego/hertz/pkg/app" +) + +{{define "M"}} +{{- if ne .Children.Len 0}} +func {{.GroupMiddleware}}Mw() []app.HandlerFunc { + // your code... + return nil +} +{{end}} +{{- if ne .Handler ""}} +func {{.HandlerMiddleware}}Mw() []app.HandlerFunc { + // your code... + return nil +} +{{end}} +{{range $_, $router := $.Children}}{{template "M" $router}}{{end}} +{{- end}} + +{{template "M" .Router}} + + `, + }, + { + Path: defaultClientDir + sp + clientTplName, + Delims: [2]string{"{{", "}}"}, + Body: `// Code generated by hertz generator. + +package {{$.PackageName}} + +import ( + "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/common/config" +) + +type {{.ServiceName}}Client struct { + client * client.Client +} + +func New{{.ServiceName}}Client(opt ...config.ClientOption) (*{{.ServiceName}}Client, error) { + c, err := client.NewClient(opt...) + if err != nil { + return nil, err + } + + return &{{.ServiceName}}Client{ + client: c, + }, nil +} + `, + }, + { + Path: defaultHandlerDir + sp + handlerSingleTplName, + Delims: [2]string{"{{", "}}"}, + Body: ` +{{.Comment}} +func {{.Name}}(ctx context.Context, c *app.RequestContext) { + var err error + {{if ne .RequestTypeName "" -}} + var req {{.RequestTypeName}} + err = c.BindAndValidate(&req) + if err != nil { + c.String(consts.StatusBadRequest, err.Error()) + return + } + {{end}} + resp := new({{.ReturnTypeName}}) + + c.{{.Serializer}}(consts.StatusOK, resp) +} +`, + }, + { + Path: defaultRouterDir + sp + middlewareSingleTplName, + Delims: [2]string{"{{", "}}"}, + Body: ` +func {{.MiddleWare}}Mw() []app.HandlerFunc { + // your code... + return nil +} +`, + }, + { + Path: defaultRouterDir + sp + hertzClientTplName, + Delims: [2]string{"{{", "}}"}, + Body: hertzClientTpl, + }, + { + Path: defaultRouterDir + sp + idlClientName, + Delims: [2]string{"{{", "}}"}, + Body: idlClientTpl, + }, + }, +} + +var hertzClientTpl = `// Code generated by hz. + +package {{.PackageName}} + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "reflect" + "regexp" + "strings" + + hertz_client "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/common/config" + "github.com/cloudwego/hertz/pkg/common/errors" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/client" +) + +type use interface { + Use(mws ...hertz_client.Middleware) +} + +// Definition of global data and types. +type ResponseResultDecider func(statusCode int, rawResponse *protocol.Response) (isError bool) + +type ( + bindRequestBodyFunc func(c *cli, r *request) (contentType string, body io.Reader, err error) + beforeRequestFunc func(*cli, *request) error + afterResponseFunc func(*cli, *response) error +) + +var ( + hdrContentTypeKey = http.CanonicalHeaderKey("Content-Type") + hdrContentEncodingKey = http.CanonicalHeaderKey("Content-Encoding") + + plainTextType = "text/plain; charset=utf-8" + jsonContentType = "application/json; charset=utf-8" + formContentType = "multipart/form-data" + + jsonCheck = regexp.MustCompile(` + "`(?i:(application|text)/(json|.*\\+json|json\\-.*)(; |$))`)\n" + + `xmlCheck = regexp.MustCompile(` + "`(?i:(application|text)/(xml|.*\\+xml)(; |$))`)\n" + + ` +) + +// Configuration of client +type Option struct { + f func(*Options) +} + +type Options struct { + hostUrl string + enumAsInt bool + doer client.Doer + header http.Header + requestBodyBind bindRequestBodyFunc + responseResultDecider ResponseResultDecider + middlewares []hertz_client.Middleware + clientOption []config.ClientOption +} + +func getOptions(ops ...Option) *Options { + opts := &Options{} + for _, do := range ops { + do.f(opts) + } + return opts +} + +// WithHertzClientOption is used to pass configuration for the hertz client +func WithHertzClientOption(opt ...config.ClientOption) Option { + return Option{func(op *Options) { + op.clientOption = append(op.clientOption, opt...) + }} +} + +// WithHertzClientMiddleware is used to register the middleware for the hertz client +func WithHertzClientMiddleware(mws ...hertz_client.Middleware) Option { + return Option{func(op *Options) { + op.middlewares = append(op.middlewares, mws...) + }} +} + +// WithHertzClient is used to register a custom hertz client +func WithHertzClient(client client.Doer) Option { + return Option{func(op *Options) { + op.doer = client + }} +} + +// WithHeader is used to add the default header, which is carried by every request +func WithHeader(header http.Header) Option { + return Option{func(op *Options) { + op.header = header + }} +} + +// WithResponseResultDecider configure custom deserialization of http response to response struct +func WithResponseResultDecider(decider ResponseResultDecider) Option { + return Option{func(op *Options) { + op.responseResultDecider = decider + }} +} + +// WithQueryEnumAsInt is used to set enum as int for query parameters +func WithQueryEnumAsInt(enable bool) Option { + return Option{func(op *Options) { + op.enumAsInt = enable + }} +} + +func withHostUrl(HostUrl string) Option { + return Option{func(op *Options) { + op.hostUrl = HostUrl + }} +} + +// underlying client +type cli struct { + hostUrl string + enumAsInt bool + doer client.Doer + header http.Header + bindRequestBody bindRequestBodyFunc + responseResultDecider ResponseResultDecider + + beforeRequest []beforeRequestFunc + afterResponse []afterResponseFunc +} + +func (c *cli) Use(mws ...hertz_client.Middleware) error { + u, ok := c.doer.(use) + if !ok { + return errors.NewPublic("doer does not support middleware, choose the right doer.") + } + u.Use(mws...) + return nil +} + +func newClient(opts *Options) (*cli, error) { + if opts.requestBodyBind == nil { + opts.requestBodyBind = defaultRequestBodyBind + } + if opts.responseResultDecider == nil { + opts.responseResultDecider = defaultResponseResultDecider + } + if opts.doer == nil { + cli, err := hertz_client.NewClient(opts.clientOption...) + if err != nil { + return nil, err + } + opts.doer = cli + } + + c := &cli{ + hostUrl: opts.hostUrl, + enumAsInt: opts.enumAsInt, + doer: opts.doer, + header: opts.header, + bindRequestBody: opts.requestBodyBind, + responseResultDecider: opts.responseResultDecider, + beforeRequest: []beforeRequestFunc{ + parseRequestURL, + parseRequestHeader, + createHTTPRequest, + }, + afterResponse: []afterResponseFunc{ + parseResponseBody, + }, + } + + if len(opts.middlewares) != 0 { + if err := c.Use(opts.middlewares...); err != nil { + return nil, err + } + } + return c, nil +} + +func (c *cli) execute(req *request) (*response, error) { + var err error + for _, f := range c.beforeRequest { + if err = f(c, req); err != nil { + return nil, err + } + } + + if hostHeader := req.header.Get("Host"); hostHeader != "" { + req.rawRequest.Header.SetHost(hostHeader) + } + + resp := protocol.Response{} + + err = c.doer.Do(req.ctx, req.rawRequest, &resp) + + response := &response{ + request: req, + rawResponse: &resp, + } + + if err != nil { + return response, err + } + + body, err := resp.BodyE() + if err != nil { + return nil, err + } + + if strings.EqualFold(resp.Header.Get(hdrContentEncodingKey), "gzip") && resp.Header.ContentLength() != 0 { + body, err = resp.BodyGunzip() + if err != nil { + return nil, err + } + } + + response.bodyByte = body + + response.size = int64(len(response.bodyByte)) + + // Apply Response middleware + for _, f := range c.afterResponse { + if err = f(c, response); err != nil { + break + } + } + + return response, err +} + +// r get request +func (c *cli) r() *request { + return &request{ + queryParam: url.Values{}, + header: http.Header{}, + pathParam: map[string]string{}, + formParam: map[string]string{}, + fileParam: map[string]string{}, + client: c, + } +} + +type response struct { + request *request + rawResponse *protocol.Response + + bodyByte []byte + size int64 +} + +// statusCode method returns the HTTP status code for the executed request. +func (r *response) statusCode() int { + if r.rawResponse == nil { + return 0 + } + + return r.rawResponse.StatusCode() +} + +// body method returns HTTP response as []byte array for the executed request. +func (r *response) body() []byte { + if r.rawResponse == nil { + return []byte{} + } + return r.bodyByte +} + +// Header method returns the response headers +func (r *response) header() http.Header { + if r.rawResponse == nil { + return http.Header{} + } + h := http.Header{} + r.rawResponse.Header.VisitAll(func(key, value []byte) { + h.Add(string(key), string(value)) + }) + + return h +} + +type request struct { + client *cli + url string + method string + queryParam url.Values + header http.Header + pathParam map[string]string + formParam map[string]string + fileParam map[string]string + bodyParam interface{} + rawRequest *protocol.Request + ctx context.Context + requestOptions []config.RequestOption + result interface{} + Error interface{} +} + +func (r *request) setContext(ctx context.Context) *request { + r.ctx = ctx + return r +} + +func (r *request) context() context.Context { + return r.ctx +} + +func (r *request) setHeader(header, value string) *request { + r.header.Set(header, value) + return r +} + +func (r *request) addHeader(header, value string) *request { + r.header.Add(header, value) + return r +} + +func (r *request) addHeaders(params map[string]string) *request { + for k, v := range params { + r.addHeader(k, v) + } + return r +} + + +func (r *request) setQueryParam(param string, value interface{}) *request { + v := reflect.ValueOf(value) + switch v.Kind() { + case reflect.Slice, reflect.Array: + for index := 0; index < v.Len(); index++ { + r.queryParam.Add(param, fmt.Sprint(v.Index(index).Interface())) + } + case reflect.Int32, reflect.Int64: + if r.client.enumAsInt { + r.queryParam.Add(param, fmt.Sprintf("%d", v.Interface())) + } else { + r.queryParam.Add(param, fmt.Sprint(v)) + } + default: + r.queryParam.Set(param, fmt.Sprint(v)) + } + return r +} + +func (r *request) setResult(res interface{}) *request { + r.result = res + return r +} + +func (r *request) setError(err interface{}) *request { + r.Error = err + return r +} + +func (r *request) setHeaders(headers map[string]string) *request { + for h, v := range headers { + r.setHeader(h, v) + } + + return r +} + +func (r *request) setQueryParams(params map[string]interface{}) *request { + for p, v := range params { + r.setQueryParam(p, v) + } + + return r +} + +func (r *request) setPathParams(params map[string]string) *request { + for p, v := range params { + r.pathParam[p] = v + } + return r +} + +func (r *request) setFormParams(params map[string]string) *request { + for p, v := range params { + r.formParam[p] = v + } + return r +} + +func (r *request) setFormFileParams(params map[string]string) *request { + for p, v := range params { + r.fileParam[p] = v + } + return r +} + +func (r *request) setBodyParam(body interface{}) *request { + r.bodyParam = body + return r +} + +func (r *request) setRequestOption(option ...config.RequestOption) *request { + r.requestOptions = append(r.requestOptions, option...) + return r +} + +func (r *request) execute(method, url string) (*response, error) { + r.method = method + r.url = url + return r.client.execute(r) +} + +func parseRequestURL(c *cli, r *request) error { + if len(r.pathParam) > 0 { + for p, v := range r.pathParam { + r.url = strings.Replace(r.url, ":"+p, url.PathEscape(v), -1) + } + } + + // Parsing request URL + reqURL, err := url.Parse(r.url) + if err != nil { + return err + } + + // If request.URL is relative path then added c.HostURL into + // the request URL otherwise request.URL will be used as-is + if !reqURL.IsAbs() { + r.url = reqURL.String() + if len(r.url) > 0 && r.url[0] != '/' { + r.url = "/" + r.url + } + + reqURL, err = url.Parse(c.hostUrl + r.url) + if err != nil { + return err + } + } + + // Adding Query Param + query := make(url.Values) + + for k, v := range r.queryParam { + // remove query param from client level by key + // since overrides happens for that key in the request + query.Del(k) + for _, iv := range v { + query.Add(k, iv) + } + } + + if len(query) > 0 { + if isStringEmpty(reqURL.RawQuery) { + reqURL.RawQuery = query.Encode() + } else { + reqURL.RawQuery = reqURL.RawQuery + "&" + query.Encode() + } + } + + r.url = reqURL.String() + + return nil +} + +func isStringEmpty(str string) bool { + return len(strings.TrimSpace(str)) == 0 +} + +func parseRequestHeader(c *cli, r *request) error { + hdr := make(http.Header) + if c.header != nil { + for k := range c.header { + hdr[k] = append(hdr[k], c.header[k]...) + } + } + + for k := range r.header { + hdr.Del(k) + hdr[k] = append(hdr[k], r.header[k]...) + } + + if len(r.formParam) != 0 || len(r.fileParam) != 0 { + hdr.Add(hdrContentTypeKey, formContentType) + } + + r.header = hdr + return nil +} + +// detectContentType method is used to figure out "request.Body" content type for request header +func detectContentType(body interface{}) string { + contentType := plainTextType + kind := reflect.Indirect(reflect.ValueOf(body)).Kind() + switch kind { + case reflect.Struct, reflect.Map: + contentType = jsonContentType + case reflect.String: + contentType = plainTextType + default: + if b, ok := body.([]byte); ok { + contentType = http.DetectContentType(b) + } else if kind == reflect.Slice { + contentType = jsonContentType + } + } + + return contentType +} + +func defaultRequestBodyBind(c *cli, r *request) (contentType string, body io.Reader, err error) { + if !isPayloadSupported(r.method) { + return + } + var bodyBytes []byte + contentType = r.header.Get(hdrContentTypeKey) + if isStringEmpty(contentType) { + contentType = detectContentType(r.bodyParam) + r.header.Set(hdrContentTypeKey, contentType) + } + kind := reflect.Indirect(reflect.ValueOf(r.bodyParam)).Kind() + if isJSONType(contentType) && + (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) { + bodyBytes, err = json.Marshal(r.bodyParam) + } else if isXMLType(contentType) && (kind == reflect.Struct) { + bodyBytes, err = xml.Marshal(r.bodyParam) + } + if err != nil { + return + } + return contentType, strings.NewReader(string(bodyBytes)), nil +} + +func isPayloadSupported(m string) bool { + return !(m == http.MethodHead || m == http.MethodOptions || m == http.MethodGet || m == http.MethodDelete) +} + +func createHTTPRequest(c *cli, r *request) (err error) { + contentType, body, err := c.bindRequestBody(c, r) + if !isStringEmpty(contentType) { + r.header.Set(hdrContentTypeKey, contentType) + } + if err == nil { + r.rawRequest = protocol.NewRequest(r.method, r.url, body) + if contentType == formContentType && isPayloadSupported(r.method) { + if r.rawRequest.IsBodyStream() { + r.rawRequest.ResetBody() + } + r.rawRequest.SetMultipartFormData(r.formParam) + r.rawRequest.SetFiles(r.fileParam) + } + for key, values := range r.header { + for _, val := range values { + r.rawRequest.Header.Add(key, val) + } + } + r.rawRequest.SetOptions(r.requestOptions...) + } + return err +} + +func silently(_ ...interface{}) {} + +// defaultResponseResultDecider method returns true if HTTP status code >= 400 otherwise false. +func defaultResponseResultDecider(statusCode int, rawResponse *protocol.Response) bool { + return statusCode > 399 +} + +// IsJSONType method is to check JSON content type or not +func isJSONType(ct string) bool { + return jsonCheck.MatchString(ct) +} + +// IsXMLType method is to check XML content type or not +func isXMLType(ct string) bool { + return xmlCheck.MatchString(ct) +} + +func parseResponseBody(c *cli, res *response) (err error) { + if res.statusCode() == http.StatusNoContent { + return + } + // Handles only JSON or XML content type + ct := res.header().Get(hdrContentTypeKey) + + isError := c.responseResultDecider(res.statusCode(), res.rawResponse) + if isError { + if res.request.Error != nil { + if isJSONType(ct) || isXMLType(ct) { + err = unmarshalContent(ct, res.bodyByte, res.request.Error) + } + } else { + jsonByte, jsonErr := json.Marshal(map[string]interface{}{ + "status_code": res.rawResponse.StatusCode(), + "body": string(res.bodyByte), + }) + if jsonErr != nil { + return jsonErr + } + err = fmt.Errorf(string(jsonByte)) + } + } else if res.request.result != nil { + if isJSONType(ct) || isXMLType(ct) { + err = unmarshalContent(ct, res.bodyByte, res.request.result) + return + } + } + return +} + +// unmarshalContent content into object from JSON or XML +func unmarshalContent(ct string, b []byte, d interface{}) (err error) { + if isJSONType(ct) { + err = json.Unmarshal(b, d) + } else if isXMLType(ct) { + err = xml.Unmarshal(b, d) + } + + return +} + +` + +var idlClientTpl = `// Code generated by hertz generator. + +package {{.PackageName}} + +import ( + "context" + "fmt" + + "github.com/cloudwego/hertz/pkg/common/config" + "github.com/cloudwego/hertz/pkg/protocol" +{{- range $k, $v := .Imports}} + {{$k}} "{{$v.Package}}" +{{- end}} +) + +// unused protection +var ( + _ = fmt.Formatter(nil) +) + +type Client interface { + {{range $_, $MethodInfo := .ClientMethods}} + {{$MethodInfo.Name}}(context context.Context, req *{{$MethodInfo.RequestTypeName}}, reqOpt ...config.RequestOption) (resp *{{$MethodInfo.ReturnTypeName}}, rawResponse *protocol.Response, err error) + {{end}} +} + +type {{.ServiceName}}Client struct { + client *cli +} + +func New{{.ServiceName}}Client(hostUrl string, ops ...Option) (Client, error) { + opts := getOptions(append(ops, withHostUrl(hostUrl))...) + cli, err := newClient(opts) + if err != nil { + return nil, err + } + return &{{.ServiceName}}Client{ + client: cli, + }, nil +} + +{{range $_, $MethodInfo := .ClientMethods}} +func (s *{{$.ServiceName}}Client) {{$MethodInfo.Name}}(context context.Context, req *{{$MethodInfo.RequestTypeName}}, reqOpt ...config.RequestOption) (resp *{{$MethodInfo.ReturnTypeName}}, rawResponse *protocol.Response, err error) { + httpResp := &{{$MethodInfo.ReturnTypeName}}{} + ret, err := s.client.r(). + setContext(context). + setQueryParams(map[string]interface{}{ + {{$MethodInfo.QueryParamsCode}} + }). + setPathParams(map[string]string{ + {{$MethodInfo.PathParamsCode}} + }). + addHeaders(map[string]string{ + {{$MethodInfo.HeaderParamsCode}} + }). + setFormParams(map[string]string{ + {{$MethodInfo.FormValueCode}} + }). + setFormFileParams(map[string]string{ + {{$MethodInfo.FormFileCode}} + }). + {{$MethodInfo.BodyParamsCode}} + setRequestOption(reqOpt...). + setResult(httpResp). + execute("{{if EqualFold $MethodInfo.HTTPMethod "Any"}}POST{{else}}{{ $MethodInfo.HTTPMethod }}{{end}}", "{{$MethodInfo.Path}}") + if err != nil { + return nil, nil, err + } + + resp = httpResp + rawResponse = ret.rawResponse + return resp, rawResponse, nil +} +{{end}} + +var defaultClient, _ = New{{.ServiceName}}Client("{{.BaseDomain}}") + +func ConfigDefaultClient(ops ...Option) (err error) { + defaultClient, err = New{{.ServiceName}}Client("{{.BaseDomain}}", ops...) + return +} + +{{range $_, $MethodInfo := .ClientMethods}} +func {{$MethodInfo.Name}}(context context.Context, req *{{$MethodInfo.RequestTypeName}}, reqOpt ...config.RequestOption) (resp *{{$MethodInfo.ReturnTypeName}}, rawResponse *protocol.Response, err error) { + return defaultClient.{{$MethodInfo.Name}}(context, req, reqOpt...) +} +{{end}} +` diff --git a/generator/router.go b/generator/router.go new file mode 100644 index 0000000..7caa525 --- /dev/null +++ b/generator/router.go @@ -0,0 +1,472 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import ( + "bytes" + "fmt" + "io/ioutil" + "path/filepath" + "regexp" + "sort" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/util" +) + +type Router struct { + FilePath string + PackageName string + HandlerPackages map[string]string // {{basename}}:{{import_path}} + Router *RouterNode + IdlName string +} + +type RouterNode struct { + GroupName string // current group name(the parent middleware name), used to register route. example: {{.GroupName}}.{{HttpMethod}} + MiddleWare string // current node middleware, used to be group name for children. + HandlerMiddleware string + GroupMiddleware string + PathPrefix string + + Path string + Parent *RouterNode + Children childrenRouterInfo + + Handler string // {{HandlerPackage}}.{{HandlerName}} + HandlerPackage string + HandlerPackageAlias string + HttpMethod string +} + +type RegisterInfo struct { + PackageName string + DepPkgAlias string + DepPkg string + RegisterName string +} + +// NewRouterTree contains "/" as root node +func NewRouterTree() *RouterNode { + return &RouterNode{ + GroupName: "root", + MiddleWare: "root", + GroupMiddleware: "root", + Path: "/", + Parent: nil, + } +} + +func (routerNode *RouterNode) Sort() { + sort.Sort(routerNode.Children) +} + +func (routerNode *RouterNode) Update(method *HttpMethod, handlerType, handlerPkg string) error { + if method.Path == "" { + return fmt.Errorf("empty path for method '%s'", method.Name) + } + paths := strings.Split(method.Path, "/") + if paths[0] == "" { + paths = paths[1:] + } + parent, last := routerNode.FindNearest(paths) + if last == len(paths) { + return fmt.Errorf("path '%s' has been registered", method.Path) + } + name := util.ToVarName(paths[:last]) + parent.Insert(name, method, handlerType, paths[last:], handlerPkg) + parent.Sort() + return nil +} + +func (routerNode *RouterNode) RawHandlerName() string { + parts := strings.Split(routerNode.Handler, ".") + handlerName := parts[len(parts)-1] + return handlerName +} + +// DyeGroupName traverses the routing tree in depth and names the handler/group middleware for each node. +// If snakeStyleMiddleware is set to true, the name style of the middleware will use snake name style. +func (routerNode *RouterNode) DyeGroupName(snakeStyleMiddleware bool) error { + groups := []string{"root"} + + hook := func(layer int, node *RouterNode) error { + node.GroupName = groups[layer] + if node.MiddleWare == "" { + pname := node.Path + if len(pname) > 1 && pname[0] == '/' { + pname = pname[1:] + } + + if node.Parent != nil { + node.PathPrefix = node.Parent.PathPrefix + "_" + util.ToGoFuncName(pname) + } else { + node.PathPrefix = "_" + util.ToGoFuncName(pname) + } + + handlerMiddlewareName := "" + isLeafNode := false + if len(node.Handler) != 0 { + handlerMiddlewareName = node.RawHandlerName() + // If it is a leaf node, then "group middleware name" and "handler middleware name" are the same + if len(node.Children) == 0 { + pname = handlerMiddlewareName + isLeafNode = true + } + } + + pname = convertToMiddlewareName(pname) + handlerMiddlewareName = convertToMiddlewareName(handlerMiddlewareName) + + if isLeafNode { + name, err := util.GetMiddlewareUniqueName(pname) + if err != nil { + return fmt.Errorf("get unique name for middleware '%s' failed, err: %v", name, err) + } + pname = name + handlerMiddlewareName = name + } else { + var err error + pname, err = util.GetMiddlewareUniqueName(pname) + if err != nil { + return fmt.Errorf("get unique name for middleware '%s' failed, err: %v", pname, err) + } + handlerMiddlewareName, err = util.GetMiddlewareUniqueName(handlerMiddlewareName) + if err != nil { + return fmt.Errorf("get unique name for middleware '%s' failed, err: %v", handlerMiddlewareName, err) + } + } + node.MiddleWare = "_" + pname + if len(node.Handler) != 0 { + node.HandlerMiddleware = "_" + handlerMiddlewareName + if snakeStyleMiddleware { + node.HandlerMiddleware = "_" + node.RawHandlerName() + } + } + node.GroupMiddleware = node.MiddleWare + if snakeStyleMiddleware { + node.GroupMiddleware = node.PathPrefix + } + } + if layer >= len(groups)-1 { + groups = append(groups, node.MiddleWare) + } else { + groups[layer+1] = node.MiddleWare + } + return nil + } + + // Deep traversal from the 0th level of the routing tree. + err := routerNode.DFS(0, hook) + return err +} + +func (routerNode *RouterNode) DFS(i int, hook func(layer int, node *RouterNode) error) error { + if routerNode == nil { + return nil + } + err := hook(i, routerNode) + if err != nil { + return err + } + for _, n := range routerNode.Children { + err = n.DFS(i+1, hook) + if err != nil { + return err + } + } + return nil +} + +var handlerPkgMap map[string]string + +func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerType string, paths []string, handlerPkg string) { + cur := routerNode + for i, p := range paths { + c := &RouterNode{ + Path: "/" + p, + Parent: cur, + } + if i == len(paths)-1 { + // generate handler by method + if len(handlerPkg) != 0 { + // get a unique package alias for every handler + pkgAlias := filepath.Base(handlerPkg) + pkgAlias = util.ToVarName([]string{pkgAlias}) + val, exist := handlerPkgMap[handlerPkg] + if !exist { + pkgAlias, _ = util.GetHandlerPackageUniqueName(pkgAlias) + if len(handlerPkgMap) == 0 { + handlerPkgMap = make(map[string]string, 10) + } + handlerPkgMap[handlerPkg] = pkgAlias + } else { + pkgAlias = val + } + c.HandlerPackageAlias = pkgAlias + c.Handler = pkgAlias + "." + method.Name + c.HandlerPackage = handlerPkg + method.RefPackage = c.HandlerPackage + method.RefPackageAlias = c.HandlerPackageAlias + } else { // generate handler by service + c.Handler = handlerType + "." + method.Name + } + c.HttpMethod = getHttpMethod(method.HTTPMethod) + } + if cur.Children == nil { + cur.Children = make([]*RouterNode, 0, 1) + } + cur.Children = append(cur.Children, c) + cur = c + } +} + +func getHttpMethod(method string) string { + if strings.EqualFold(method, "Any") { + return "Any" + } + return strings.ToUpper(method) +} + +func (routerNode *RouterNode) FindNearest(paths []string) (*RouterNode, int) { + ns := len(paths) + cur := routerNode + i := 0 + path := paths[i] + for j := 0; j < len(cur.Children); j++ { + c := cur.Children[j] + if ("/" + path) == c.Path { + i++ + if i == ns { + return cur, i - 1 + } + path = paths[i] + cur = c + j = -1 + } + } + return cur, i +} + +type childrenRouterInfo []*RouterNode + +// Len is the number of elements in the collection. +func (c childrenRouterInfo) Len() int { + return len(c) +} + +// Less reports whether the element with +// index i should sort before the element with index j. +func (c childrenRouterInfo) Less(i, j int) bool { + ci := c[i].Path + if len(c[i].Children) != 0 { + ci = ci[1:] + } + cj := c[j].Path + if len(c[j].Children) != 0 { + cj = cj[1:] + } + return ci < cj +} + +// Swap swaps the elements with indexes i and j. +func (c childrenRouterInfo) Swap(i, j int) { + c[i], c[j] = c[j], c[i] +} + +var ( + regRegisterV3 = regexp.MustCompile(insertPointPatternNew) + regImport = regexp.MustCompile(`import \(\n`) +) + +func (pkgGen *HttpPackageGenerator) updateRegister(pkg, rDir, pkgName string, idlName string) error { + if pkgGen.tplsInfo[registerTplName].Disable { + return nil + } + register := RegisterInfo{ + PackageName: filepath.Base(rDir), + DepPkgAlias: strings.ReplaceAll(pkgName, "/", "_"), + DepPkg: pkg, + } + register.RegisterName = register.DepPkgAlias + ".Register" + idlName + "(r)" + registerPath := filepath.Join(rDir, registerTplName) + isExist, err := util.PathExist(registerPath) + if err != nil { + return err + } + if !isExist { + return pkgGen.TemplateGenerator.Generate(register, registerTplName, registerPath, false) + } + + file, err := ioutil.ReadFile(registerPath) + if err != nil { + return fmt.Errorf("read register '%s' failed, err: %v", registerPath, err.Error()) + } + + insertReg := register.RegisterName + if !bytes.Contains(file, []byte(insertReg)) { + t := !bytes.Contains(file, []byte(register.DepPkg)) + + if t { + file, err = util.AddImport(registerPath, register.DepPkgAlias, register.DepPkg) + if err != nil { + return err + } + } + + // + //if bytes.Contains(file, []byte(insertReg)) { + // return fmt.Errorf("the router(%s) has been registered", insertReg) + //} + + subIndexReg := regRegisterV3.FindSubmatchIndex(file) + if len(subIndexReg) != 2 || subIndexReg[0] < 1 { + return fmt.Errorf("wrong format %s: insert-point '%s' not found", string(file), insertPointPatternNew) + } + + bufReg := bytes.NewBuffer(nil) + bufReg.Write(file[:subIndexReg[1]]) + bufReg.WriteString("\n") + bufReg.WriteString(insertReg) + if t { + bufReg.WriteString("\n\t") + } + bufReg.Write(file[subIndexReg[1]:]) + + pkgGen.files = append(pkgGen.files, File{registerPath, string(bufReg.Bytes()), false, registerTplName}) + } + return nil +} + +func (pkgGen *HttpPackageGenerator) genRouter(pkg *HttpPackage, root *RouterNode, handlerPackage, routerDir, routerPackage string) error { + err := root.DyeGroupName(pkgGen.SnakeStyleMiddleware) + if err != nil { + return err + } + idleName := util.ToCamelCase(util.BaseNameAndTrim(pkg.IdlName)) + router := Router{ + FilePath: filepath.Join(routerDir, util.BaseNameAndTrim(pkg.IdlName)+".go"), + PackageName: filepath.Base(routerDir), + HandlerPackages: map[string]string{ + util.BaseName(handlerPackage, ""): handlerPackage, + }, + Router: root, + IdlName: idleName, + } + + if pkgGen.HandlerByMethod { + handlerMap := make(map[string]string, 1) + hook := func(layer int, node *RouterNode) error { + if len(node.HandlerPackage) != 0 { + handlerMap[node.HandlerPackageAlias] = node.HandlerPackage + } + return nil + } + root.DFS(0, hook) + router.HandlerPackages = handlerMap + } + + // store router info + pkg.RouterInfo = &router + + if !pkgGen.tplsInfo[routerTplName].Disable { + if err := pkgGen.TemplateGenerator.Generate(router, routerTplName, router.FilePath, false); err != nil { + return fmt.Errorf("generate router %s failed, err: %v", router.FilePath, err.Error()) + } + } + if err := pkgGen.updateMiddlewareReg(router, middlewareTplName, filepath.Join(routerDir, "middleware.go")); err != nil { + return fmt.Errorf("generate middleware %s failed, err: %v", filepath.Join(routerDir, "middleware.go"), err.Error()) + } + + if err := pkgGen.updateRegister(routerPackage, pkgGen.RouterDir, pkg.Package, idleName); err != nil { + return fmt.Errorf("update register for %s failed, err: %v", filepath.Join(routerDir, registerTplName), err.Error()) + } + return nil +} + +func (pkgGen *HttpPackageGenerator) updateMiddlewareReg(router interface{}, middlewareTpl, filePath string) error { + if pkgGen.tplsInfo[middlewareTpl].Disable { + return nil + } + isExist, err := util.PathExist(filePath) + if err != nil { + return err + } + if !isExist { + return pkgGen.TemplateGenerator.Generate(router, middlewareTpl, filePath, false) + } + var middlewareList []string + + _ = router.(Router).Router.DFS(0, func(layer int, node *RouterNode) error { + // non-leaf node will generate group middleware + if node.Children.Len() > 0 && len(node.GroupMiddleware) > 0 { + middlewareList = append(middlewareList, node.GroupMiddleware) + } + if len(node.HandlerMiddleware) > 0 { + middlewareList = append(middlewareList, node.HandlerMiddleware) + } + return nil + }) + + file, err := ioutil.ReadFile(filePath) + if err != nil { + return err + } + + for _, mw := range middlewareList { + mwNamePattern := fmt.Sprintf(" %sMw", mw) + if pkgGen.SnakeStyleMiddleware { + mwNamePattern = fmt.Sprintf(" %s_mw", mw) + } + if bytes.Contains(file, []byte(mwNamePattern)) { + continue + } + middlewareSingleTpl := pkgGen.tpls[middlewareSingleTplName] + if middlewareSingleTpl == nil { + return fmt.Errorf("tpl %s not found", middlewareSingleTplName) + } + data := make(map[string]string, 1) + data["MiddleWare"] = mw + middlewareFunc := bytes.NewBuffer(nil) + err = middlewareSingleTpl.Execute(middlewareFunc, data) + if err != nil { + return fmt.Errorf("execute template \"%s\" failed, %v", middlewareSingleTplName, err) + } + + buf := bytes.NewBuffer(nil) + _, err = buf.Write(file) + if err != nil { + return fmt.Errorf("write middleware \"%s\" failed, %v", mw, err) + } + _, err = buf.Write(middlewareFunc.Bytes()) + if err != nil { + return fmt.Errorf("write middleware \"%s\" failed, %v", mw, err) + } + file = buf.Bytes() + } + + pkgGen.files = append(pkgGen.files, File{filePath, string(file), false, middlewareTplName}) + + return nil +} + +// convertToMiddlewareName converts a route path to a middleware name +func convertToMiddlewareName(path string) string { + path = util.ToVarName([]string{path}) + path = strings.ToLower(path) + return path +} diff --git a/generator/template.go b/generator/template.go new file mode 100644 index 0000000..b339fbb --- /dev/null +++ b/generator/template.go @@ -0,0 +1,333 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import ( + "bytes" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "text/template" + + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/hertz/cmd/hz/util/logs" +) + +var DefaultDelimiters = [2]string{"{{", "}}"} + +type TemplateConfig struct { + Layouts []Template `yaml:"layouts"` +} + +const ( + Skip = "skip" + Cover = "cover" + Append = "append" +) + +type Template struct { + Default bool // Is it the default template + Path string `yaml:"path"` // The generated path and its filename, such as biz/handler/ping.go + Delims [2]string `yaml:"delims"` // Template Action Instruction Identifier, default: "{{}}" + Body string `yaml:"body"` // Render template, currently only supports go template syntax + Disable bool `yaml:"disable"` // Disable generating file, used to disable default package template + LoopMethod bool `yaml:"loop_method"` // Loop generate files based on "method" + LoopService bool `yaml:"loop_service"` // Loop generate files based on "service" + UpdateBehavior UpdateBehavior `yaml:"update_behavior"` // Update command behavior; 0:unchanged, 1:regenerate, 2:append +} + +type UpdateBehavior struct { + Type string `yaml:"type"` // Update behavior type: skip/cover/append + // the following variables are used for append update + AppendKey string `yaml:"append_key"` // Append content based in key; for example: 'method'/'service' + InsertKey string `yaml:"insert_key"` // Insert content by "insert_key" + AppendTpl string `yaml:"append_content_tpl"` // Append content if UpdateBehavior is "append" + ImportTpl []string `yaml:"import_tpl"` // Import insert template + AppendLocation string `yaml:"append_location"` // AppendLocation specifies the location of append, the default is the end of the file +} + +// TemplateGenerator contains information about the output template +type TemplateGenerator struct { + OutputDir string + Config *TemplateConfig + Excludes []string + tpls map[string]*template.Template // "template name" -> "Template", it is used get the "parsed template" directly + tplsInfo map[string]*Template // "template name" -> "template info", it is used to get the original "template information" + dirs map[string]bool + isPackageTpl bool + + files []File + excludedFiles map[string]*File +} + +func (tg *TemplateGenerator) Init() error { + if tg.Config == nil { + return errors.New("config not set yet") + } + + if tg.tpls == nil { + tg.tpls = make(map[string]*template.Template, len(tg.Config.Layouts)) + } + if tg.tplsInfo == nil { + tg.tplsInfo = make(map[string]*Template, len(tg.Config.Layouts)) + } + if tg.dirs == nil { + tg.dirs = make(map[string]bool) + } + + for _, l := range tg.Config.Layouts { + if tg.isPackageTpl && IsDefaultPackageTpl(l.Path) { + continue + } + + // check if is a directory + var noFile bool + if strings.HasSuffix(l.Path, string(filepath.Separator)) { + noFile = true + } + path := l.Path + if filepath.IsAbs(path) { + return fmt.Errorf("absolute template path '%s' is not allowed", path) + } + dir := filepath.Dir(path) + isExist, err := util.PathExist(filepath.Join(tg.OutputDir, dir)) + if err != nil { + return fmt.Errorf("check directory '%s' failed, err: %v", dir, err.Error()) + } + if isExist { + tg.dirs[dir] = true + } else { + tg.dirs[dir] = false + } + + if noFile { + continue + } + + // parse templates + if _, ok := tg.tpls[path]; ok { + continue + } + err = tg.loadLayout(l, path, false) + if err != nil { + return err + } + } + + excludes := make(map[string]*File, len(tg.Excludes)) + for _, f := range tg.Excludes { + excludes[f] = &File{} + } + + tg.excludedFiles = excludes + return nil +} + +func (tg *TemplateGenerator) loadLayout(layout Template, tplName string, isDefaultTpl bool) error { + delims := DefaultDelimiters + if layout.Delims[0] != "" && layout.Delims[1] != "" { + delims = layout.Delims + } + // insert template funcs + tpl := template.New(tplName).Funcs(funcMap) + tpl = tpl.Delims(delims[0], delims[1]) + var err error + if tpl, err = tpl.Parse(layout.Body); err != nil { + return fmt.Errorf("parse template '%s' failed, err: %v", tplName, err.Error()) + } + layout.Default = isDefaultTpl + tg.tpls[tplName] = tpl + tg.tplsInfo[tplName] = &layout + return nil +} + +func (tg *TemplateGenerator) Generate(input interface{}, tplName, filepath string, noRepeat bool) error { + // check if "*" (global scope) data exists, and stores it to all + var all map[string]interface{} + if data, ok := input.(map[string]interface{}); ok { + ad, ok := data["*"] + if ok { + all = ad.(map[string]interface{}) + } + if all == nil { + all = map[string]interface{}{} + } + all["hzVersion"] = meta.Version + } + + file := bytes.NewBuffer(nil) + if tplName != "" { + tpl := tg.tpls[tplName] + if tpl == nil { + return fmt.Errorf("tpl %s not found", tplName) + } + if err := tpl.Execute(file, input); err != nil { + return fmt.Errorf("render template '%s' failed, err: %v", tplName, err.Error()) + } + + in := File{filepath, string(file.Bytes()), noRepeat, tplName} + tg.files = append(tg.files, in) + return nil + } + + for path, tpl := range tg.tpls { + file.Reset() + var fd interface{} + // search and merge rendering data + if data, ok := input.(map[string]interface{}); ok { + td := map[string]interface{}{} + tmp, ok := data[path] + if ok { + td = tmp.(map[string]interface{}) + } + for k, v := range all { + td[k] = v + } + fd = td + } else { + fd = input + } + if err := tpl.Execute(file, fd); err != nil { + return fmt.Errorf("render template '%s' failed, err: %v", path, err.Error()) + } + + in := File{path, string(file.Bytes()), noRepeat, tpl.Name()} + tg.files = append(tg.files, in) + } + + return nil +} + +func (tg *TemplateGenerator) Persist() error { + files := tg.files + outPath := tg.OutputDir + if !filepath.IsAbs(outPath) { + outPath, _ = filepath.Abs(outPath) + } + + for _, data := range files { + // check for -E flags + if _, ok := tg.excludedFiles[filepath.Join(data.Path)]; ok { + continue + } + + // lint file + if err := data.Lint(); err != nil { + return err + } + + // create rendered file + abPath := filepath.Join(outPath, data.Path) + abDir := filepath.Dir(abPath) + isExist, err := util.PathExist(abDir) + if err != nil { + return fmt.Errorf("check directory '%s' failed, err: %v", abDir, err.Error()) + } + if !isExist { + if err := os.MkdirAll(abDir, os.FileMode(0o744)); err != nil { + return fmt.Errorf("mkdir %s failed, err: %v", abDir, err.Error()) + } + } + + err = func() error { + file, err := os.OpenFile(abPath, os.O_CREATE|os.O_TRUNC|os.O_RDWR, os.FileMode(0o755)) + defer file.Close() + if err != nil { + return fmt.Errorf("open file '%s' failed, err: %v", abPath, err.Error()) + } + if _, err = file.WriteString(data.Content); err != nil { + return fmt.Errorf("write file '%s' failed, err: %v", abPath, err.Error()) + } + + return nil + }() + if err != nil { + return err + } + } + + tg.files = tg.files[:0] + return nil +} + +func (tg *TemplateGenerator) GetFormatAndExcludedFiles() ([]File, error) { + var files []File + outPath := tg.OutputDir + if !filepath.IsAbs(outPath) { + outPath, _ = filepath.Abs(outPath) + } + + for _, data := range tg.Files() { + if _, ok := tg.excludedFiles[filepath.Join(data.Path)]; ok { + continue + } + + // check repeat files + logs.Infof("Write %s", data.Path) + isExist, err := util.PathExist(filepath.Join(data.Path)) + if err != nil { + return nil, fmt.Errorf("check file '%s' failed, err: %v", data.Path, err.Error()) + } + if isExist && data.NoRepeat { + if data.FileTplName == handlerTplName { + logs.Warnf("Handler file(%s) has been generated.\n If you want to re-generate it, please copy and delete the file to prevent the already written code from being deleted.", data.Path) + } else if data.FileTplName == routerTplName { + logs.Warnf("Router file(%s) has been generated.\n If you want to re-generate it, please delete the file.", data.Path) + } else { + logs.Warnf("file '%s' already exists, so drop the generated file", data.Path) + } + continue + } + + // lint file + if err := data.Lint(); err != nil { + logs.Warnf("Lint file: %s failed:\n %s\n", data.Path, data.Content) + } + files = append(files, data) + } + + return files, nil +} + +func (tg *TemplateGenerator) Files() []File { + return tg.files +} + +func (tg *TemplateGenerator) Degenerate() error { + outPath := tg.OutputDir + if !filepath.IsAbs(outPath) { + outPath, _ = filepath.Abs(outPath) + } + for path := range tg.tpls { + abPath := filepath.Join(outPath, path) + if err := os.RemoveAll(abPath); err != nil { + return fmt.Errorf("remove file '%s' failed, err: %v", path, err.Error()) + } + } + for dir, exist := range tg.dirs { + if !exist { + abDir := filepath.Join(outPath, dir) + if err := os.RemoveAll(abDir); err != nil { + return fmt.Errorf("remove directory '%s' failed, err: %v", dir, err.Error()) + } + } + } + return nil +} diff --git a/generator/template_funcs.go b/generator/template_funcs.go new file mode 100644 index 0000000..d6a9837 --- /dev/null +++ b/generator/template_funcs.go @@ -0,0 +1,52 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generator + +import ( + "strings" + "text/template" + + "github.com/Masterminds/sprig/v3" + "github.com/cloudwego/hertz/cmd/hz/util" +) + +var funcMap = func() template.FuncMap { + m := template.FuncMap{ + "GetUniqueHandlerOutDir": getUniqueHandlerOutDir, + "ToSnakeCase": util.ToSnakeCase, + "Split": strings.Split, + "Trim": strings.Trim, + "EqualFold": strings.EqualFold, + } + for key, f := range sprig.TxtFuncMap() { + m[key] = f + } + return m +}() + +// getUniqueHandlerOutDir uses to get unique "api.handler_path" +func getUniqueHandlerOutDir(methods []*HttpMethod) (ret []string) { + outDirMap := make(map[string]string) + for _, method := range methods { + if _, exist := outDirMap[method.OutputDir]; !exist { + outDirMap[method.OutputDir] = method.OutputDir + ret = append(ret, method.OutputDir) + } + } + + return ret +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..712c5a2 --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module github.com/cloudwego/hertz/cmd/hz + +go 1.16 + +require ( + github.com/Masterminds/sprig/v3 v3.2.3 + github.com/cloudwego/thriftgo v0.1.7 + github.com/hashicorp/go-version v1.5.0 + github.com/jhump/protoreflect v1.12.0 + github.com/urfave/cli/v2 v2.23.0 + golang.org/x/tools v0.4.0 + google.golang.org/protobuf v1.28.0 + gopkg.in/yaml.v2 v2.4.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ef08036 --- /dev/null +++ b/go.sum @@ -0,0 +1,189 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= +github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= +github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= +github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7YgDP83g= +github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= +github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= +github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= +github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= +github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudwego/thriftgo v0.1.7 h1:mTGRv6Dtwfp0hTPZXuIHwm3vtGOuZVTrWarI0xVzUYg= +github.com/cloudwego/thriftgo v0.1.7/go.mod h1:LzeafuLSiHA9JTiWC8TIMIq64iadeObgRUhmVG1OC/w= +github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/go-version v1.5.0 h1:O293SZ2Eg+AAYijkVK3jR786Am1bhDEh2GHT0tIVE5E= +github.com/hashicorp/go-version v1.5.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4= +github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= +github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA= +github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= +github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= +github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= +github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ= +github.com/jhump/protoreflect v1.11.0/go.mod h1:U7aMIjN0NWq9swDP7xDdoMfRHb35uiuTd3Z9nFXJf5E= +github.com/jhump/protoreflect v1.12.0 h1:1NQ4FpWMgn3by/n1X0fbeKEUxP1wBt7+Oitpv01HR10= +github.com/jhump/protoreflect v1.12.0/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= +github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= +github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY= +github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= +github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/urfave/cli/v2 v2.23.0 h1:pkly7gKIeYv3olPAeNajNpLjeJrmTPYCoZWaV+2VfvE= +github.com/urfave/cli/v2 v2.23.0/go.mod h1:1CNUng3PtjQMtRzJO4FMXBQvkGtuYRxxiR9xMa7jMwI= +github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= +github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A= +golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= +golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= +golang.org/x/net v0.3.0 h1:VWL6FNY2bEEmsGVKabSlHu5Irp34xmMRoqb/9lF9lxk= +golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= +golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= +golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= +golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4= +golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.38.0 h1:/9BgsAsa5nWe26HqOlvlgJnqBuktYOLCgjCPqsa56W0= +google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/hz.exe b/hz.exe new file mode 100644 index 0000000..70a1df2 Binary files /dev/null and b/hz.exe differ diff --git a/main.go b/main.go new file mode 100644 index 0000000..403fac3 --- /dev/null +++ b/main.go @@ -0,0 +1,44 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "os" + + "github.com/cloudwego/hertz/cmd/hz/app" + "github.com/cloudwego/hertz/cmd/hz/util/logs" +) + +func main() { + // run in plugin mode + app.PluginMode() + + // run in normal mode + Run() +} + +func Run() { + defer func() { + logs.Flush() + }() + + cli := app.Init() + err := cli.Run(os.Args) + if err != nil { + logs.Errorf("%v\n", err) + } +} diff --git a/meta/const.go b/meta/const.go new file mode 100644 index 0000000..7441eab --- /dev/null +++ b/meta/const.go @@ -0,0 +1,92 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package meta + +import "runtime" + +// Version hz version +const Version = "v0.8.1" + +const DefaultServiceName = "hertz_service" + +// Mode hz run modes +type Mode int + +// SysType is the running program's operating system type +const SysType = runtime.GOOS + +const WindowsOS = "windows" + +const EnvPluginMode = "HERTZ_PLUGIN_MODE" + +// hz Commands +const ( + CmdUpdate = "update" + CmdNew = "new" + CmdModel = "model" + CmdClient = "client" +) + +// hz IDLs +const ( + IdlThrift = "thrift" + IdlProto = "proto" +) + +// Third-party Compilers +const ( + TpCompilerThrift = "thriftgo" + TpCompilerProto = "protoc" +) + +// hz Plugins +const ( + ProtocPluginName = "protoc-gen-hertz" + ThriftPluginName = "thrift-gen-hertz" +) + +// hz Errors +const ( + LoadError = 1 + GenerateLayoutError = 2 + PersistError = 3 + PluginError = 4 +) + +// Package Dir +const ( + ModelDir = "biz/model" + RouterDir = "biz/router" + HandlerDir = "biz/handler" +) + +// Backend Model Backends +type Backend string + +const ( + BackendGolang Backend = "golang" +) + +// template const value +const ( + SetBodyParam = "setBodyParam(req).\n" +) + +// TheUseOptionMessage indicates that the generating of 'model code' is aborted due to the -use option for thrift IDL. +const TheUseOptionMessage = "'model code' is not generated due to the '-use' option" + +const AddThriftReplace = "do not generate 'go.mod', please add 'replace github.com/apache/thrift => github.com/apache/thrift v0.13.0' to your 'go.mod'" diff --git a/meta/manifest.go b/meta/manifest.go new file mode 100644 index 0000000..0c6ecd0 --- /dev/null +++ b/meta/manifest.go @@ -0,0 +1,96 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package meta + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + + gv "github.com/hashicorp/go-version" + "gopkg.in/yaml.v2" +) + +const ManifestFile = ".hz" + +type Manifest struct { + Version string `yaml:"hz version"` + HandlerDir string `yaml:"handlerDir"` + ModelDir string `yaml:"modelDir"` + RouterDir string `yaml:"routerDir"` +} + +var GoVersion *gv.Version + +func init() { + // valid by unit test already, so no need to check error + GoVersion, _ = gv.NewVersion(Version) +} + +func (manifest *Manifest) InitAndValidate(dir string) error { + m, err := loadConfigFile(filepath.Join(dir, ManifestFile)) + if err != nil { + return fmt.Errorf("can not load \".hz\", err: %v", err) + } + + if len(m.Version) == 0 { + return fmt.Errorf("can not get hz version form \".hz\", current project doesn't belong to hertz framework") + } + + *manifest = *m + _, err = gv.NewVersion(manifest.Version) + if err != nil { + return fmt.Errorf("invalid hz version in \".hz\", err: %v", err) + } + + return nil +} + +const hzTitle = "// Code generated by hz. DO NOT EDIT." + +func (manifest *Manifest) String() string { + conf, _ := yaml.Marshal(*manifest) + + return hzTitle + "\n\n" + + string(conf) +} + +func (manifest *Manifest) Persist(dir string) error { + file := filepath.Join(dir, ManifestFile) + fd, err := os.OpenFile(file, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0o644)) + if err != nil { + return err + } + defer fd.Close() + _, err = fd.WriteString(manifest.String()) + return err +} + +// loadConfigFile load config file from path +func loadConfigFile(path string) (*Manifest, error) { + file, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var manifest Manifest + file = bytes.TrimPrefix(file, []byte(hzTitle)) + if err = yaml.Unmarshal(file, &manifest); err != nil { + return nil, fmt.Errorf("decode \".hz\" failed, err: %v", err) + } + return &manifest, nil +} diff --git a/meta/manifest_test.go b/meta/manifest_test.go new file mode 100644 index 0000000..35b99bb --- /dev/null +++ b/meta/manifest_test.go @@ -0,0 +1,30 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package meta + +import ( + "testing" + + gv "github.com/hashicorp/go-version" +) + +func TestValidate(t *testing.T) { + _, err := gv.NewVersion(Version) + if err != nil { + t.Fatalf("not a valid version: %s", err) + } +} diff --git a/protobuf/api/api.pb.go b/protobuf/api/api.pb.go new file mode 100644 index 0000000..4c1d8c6 --- /dev/null +++ b/protobuf/api/api.pb.go @@ -0,0 +1,679 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc v3.21.12 +// source: api.proto + +package api + +import ( + reflect "reflect" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + descriptorpb "google.golang.org/protobuf/types/descriptorpb" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +var file_api_proto_extTypes = []protoimpl.ExtensionInfo{ + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50101, + Name: "api.raw_body", + Tag: "bytes,50101,opt,name=raw_body", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50102, + Name: "api.query", + Tag: "bytes,50102,opt,name=query", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50103, + Name: "api.header", + Tag: "bytes,50103,opt,name=header", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50104, + Name: "api.cookie", + Tag: "bytes,50104,opt,name=cookie", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50105, + Name: "api.body", + Tag: "bytes,50105,opt,name=body", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50106, + Name: "api.path", + Tag: "bytes,50106,opt,name=path", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50107, + Name: "api.vd", + Tag: "bytes,50107,opt,name=vd", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50108, + Name: "api.form", + Tag: "bytes,50108,opt,name=form", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50109, + Name: "api.js_conv", + Tag: "bytes,50109,opt,name=js_conv", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50110, + Name: "api.file_name", + Tag: "bytes,50110,opt,name=file_name", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50111, + Name: "api.none", + Tag: "bytes,50111,opt,name=none", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50131, + Name: "api.form_compatible", + Tag: "bytes,50131,opt,name=form_compatible", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50132, + Name: "api.js_conv_compatible", + Tag: "bytes,50132,opt,name=js_conv_compatible", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50133, + Name: "api.file_name_compatible", + Tag: "bytes,50133,opt,name=file_name_compatible", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50134, + Name: "api.none_compatible", + Tag: "bytes,50134,opt,name=none_compatible", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 51001, + Name: "api.go_tag", + Tag: "bytes,51001,opt,name=go_tag", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50201, + Name: "api.get", + Tag: "bytes,50201,opt,name=get", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50202, + Name: "api.post", + Tag: "bytes,50202,opt,name=post", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50203, + Name: "api.put", + Tag: "bytes,50203,opt,name=put", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50204, + Name: "api.delete", + Tag: "bytes,50204,opt,name=delete", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50205, + Name: "api.patch", + Tag: "bytes,50205,opt,name=patch", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50206, + Name: "api.options", + Tag: "bytes,50206,opt,name=options", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50207, + Name: "api.head", + Tag: "bytes,50207,opt,name=head", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50208, + Name: "api.any", + Tag: "bytes,50208,opt,name=any", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50301, + Name: "api.gen_path", + Tag: "bytes,50301,opt,name=gen_path", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50302, + Name: "api.api_version", + Tag: "bytes,50302,opt,name=api_version", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50303, + Name: "api.tag", + Tag: "bytes,50303,opt,name=tag", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50304, + Name: "api.name", + Tag: "bytes,50304,opt,name=name", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50305, + Name: "api.api_level", + Tag: "bytes,50305,opt,name=api_level", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50306, + Name: "api.serializer", + Tag: "bytes,50306,opt,name=serializer", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50307, + Name: "api.param", + Tag: "bytes,50307,opt,name=param", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50308, + Name: "api.baseurl", + Tag: "bytes,50308,opt,name=baseurl", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50309, + Name: "api.handler_path", + Tag: "bytes,50309,opt,name=handler_path", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50331, + Name: "api.handler_path_compatible", + Tag: "bytes,50331,opt,name=handler_path_compatible", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.EnumValueOptions)(nil), + ExtensionType: (*int32)(nil), + Field: 50401, + Name: "api.http_code", + Tag: "varint,50401,opt,name=http_code", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.ServiceOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50402, + Name: "api.base_domain", + Tag: "bytes,50402,opt,name=base_domain", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.ServiceOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50731, + Name: "api.base_domain_compatible", + Tag: "bytes,50731,opt,name=base_domain_compatible", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.ServiceOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50732, + Name: "api.service_path", + Tag: "bytes,50732,opt,name=service_path", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MessageOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50830, + Name: "api.reserve", + Tag: "bytes,50830,opt,name=reserve", + Filename: "api.proto", + }, +} + +// Extension fields to descriptorpb.FieldOptions. +var ( + // optional string raw_body = 50101; + E_RawBody = &file_api_proto_extTypes[0] + // optional string query = 50102; + E_Query = &file_api_proto_extTypes[1] + // optional string header = 50103; + E_Header = &file_api_proto_extTypes[2] + // optional string cookie = 50104; + E_Cookie = &file_api_proto_extTypes[3] + // optional string body = 50105; + E_Body = &file_api_proto_extTypes[4] + // optional string path = 50106; + E_Path = &file_api_proto_extTypes[5] + // optional string vd = 50107; + E_Vd = &file_api_proto_extTypes[6] + // optional string form = 50108; + E_Form = &file_api_proto_extTypes[7] + // optional string js_conv = 50109; + E_JsConv = &file_api_proto_extTypes[8] + // optional string file_name = 50110; + E_FileName = &file_api_proto_extTypes[9] + // optional string none = 50111; + E_None = &file_api_proto_extTypes[10] + // 50131~50160 used to extend field option by hz + // + // optional string form_compatible = 50131; + E_FormCompatible = &file_api_proto_extTypes[11] + // optional string js_conv_compatible = 50132; + E_JsConvCompatible = &file_api_proto_extTypes[12] + // optional string file_name_compatible = 50133; + E_FileNameCompatible = &file_api_proto_extTypes[13] + // optional string none_compatible = 50134; + E_NoneCompatible = &file_api_proto_extTypes[14] + // optional string go_tag = 51001; + E_GoTag = &file_api_proto_extTypes[15] +) + +// Extension fields to descriptorpb.MethodOptions. +var ( + // optional string get = 50201; + E_Get = &file_api_proto_extTypes[16] + // optional string post = 50202; + E_Post = &file_api_proto_extTypes[17] + // optional string put = 50203; + E_Put = &file_api_proto_extTypes[18] + // optional string delete = 50204; + E_Delete = &file_api_proto_extTypes[19] + // optional string patch = 50205; + E_Patch = &file_api_proto_extTypes[20] + // optional string options = 50206; + E_Options = &file_api_proto_extTypes[21] + // optional string head = 50207; + E_Head = &file_api_proto_extTypes[22] + // optional string any = 50208; + E_Any = &file_api_proto_extTypes[23] + // optional string gen_path = 50301; + E_GenPath = &file_api_proto_extTypes[24] // The path specified by the user when the client code is generated, with a higher priority than api_version + // optional string api_version = 50302; + E_ApiVersion = &file_api_proto_extTypes[25] // Specify the value of the :version variable in path when the client code is generated + // optional string tag = 50303; + E_Tag = &file_api_proto_extTypes[26] // rpc tag, can be multiple, separated by commas + // optional string name = 50304; + E_Name = &file_api_proto_extTypes[27] // Name of rpc + // optional string api_level = 50305; + E_ApiLevel = &file_api_proto_extTypes[28] // Interface Level + // optional string serializer = 50306; + E_Serializer = &file_api_proto_extTypes[29] // Serialization method + // optional string param = 50307; + E_Param = &file_api_proto_extTypes[30] // Whether client requests take public parameters + // optional string baseurl = 50308; + E_Baseurl = &file_api_proto_extTypes[31] // Baseurl used in ttnet routing + // optional string handler_path = 50309; + E_HandlerPath = &file_api_proto_extTypes[32] // handler_path specifies the path to generate the method + // 50331~50360 used to extend method option by hz + // + // optional string handler_path_compatible = 50331; + E_HandlerPathCompatible = &file_api_proto_extTypes[33] // handler_path specifies the path to generate the method +) + +// Extension fields to descriptorpb.EnumValueOptions. +var ( + // optional int32 http_code = 50401; + E_HttpCode = &file_api_proto_extTypes[34] +) + +// Extension fields to descriptorpb.ServiceOptions. +var ( + // optional string base_domain = 50402; + E_BaseDomain = &file_api_proto_extTypes[35] + // 50731~50760 used to extend service option by hz + // + // optional string base_domain_compatible = 50731; + E_BaseDomainCompatible = &file_api_proto_extTypes[36] + // optional string service_path = 50732; + E_ServicePath = &file_api_proto_extTypes[37] +) + +// Extension fields to descriptorpb.MessageOptions. +var ( + // optional string reserve = 50830; + E_Reserve = &file_api_proto_extTypes[38] +) + +var File_api_proto protoreflect.FileDescriptor + +var file_api_proto_rawDesc = []byte{ + 0x0a, 0x09, 0x61, 0x70, 0x69, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x03, 0x61, 0x70, 0x69, + 0x1a, 0x20, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x3a, 0x3a, 0x0a, 0x08, 0x72, 0x61, 0x77, 0x5f, 0x62, 0x6f, 0x64, 0x79, 0x12, 0x1d, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xb5, 0x87, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x72, 0x61, 0x77, 0x42, 0x6f, 0x64, 0x79, 0x3a, 0x35, + 0x0a, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xb6, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x71, 0x75, 0x65, 0x72, 0x79, 0x3a, 0x37, 0x0a, 0x06, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, + 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xb7, + 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x3a, 0x37, + 0x0a, 0x06, 0x63, 0x6f, 0x6f, 0x6b, 0x69, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xb8, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x06, 0x63, 0x6f, 0x6f, 0x6b, 0x69, 0x65, 0x3a, 0x33, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x12, + 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xb9, + 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x3a, 0x33, 0x0a, 0x04, + 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x18, 0xba, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, + 0x68, 0x3a, 0x2f, 0x0a, 0x02, 0x76, 0x64, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xbb, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, + 0x76, 0x64, 0x3a, 0x33, 0x0a, 0x04, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, + 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xbc, 0x87, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x04, 0x66, 0x6f, 0x72, 0x6d, 0x3a, 0x38, 0x0a, 0x07, 0x6a, 0x73, 0x5f, 0x63, 0x6f, + 0x6e, 0x76, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x18, 0xbd, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6a, 0x73, 0x43, 0x6f, 0x6e, + 0x76, 0x3a, 0x3c, 0x0a, 0x09, 0x66, 0x69, 0x6c, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1d, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xbe, 0x87, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x3a, + 0x33, 0x0a, 0x04, 0x6e, 0x6f, 0x6e, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xbf, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x6e, 0x6f, 0x6e, 0x65, 0x3a, 0x48, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x6d, 0x5f, 0x63, 0x6f, 0x6d, + 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd3, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, + 0x66, 0x6f, 0x72, 0x6d, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x4d, + 0x0a, 0x12, 0x6a, 0x73, 0x5f, 0x63, 0x6f, 0x6e, 0x76, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x74, + 0x69, 0x62, 0x6c, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x18, 0xd4, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6a, 0x73, 0x43, + 0x6f, 0x6e, 0x76, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x51, 0x0a, + 0x14, 0x66, 0x69, 0x6c, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x61, + 0x74, 0x69, 0x62, 0x6c, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd5, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x66, 0x69, + 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, + 0x3a, 0x48, 0x0a, 0x0f, 0x6e, 0x6f, 0x6e, 0x65, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, + 0x62, 0x6c, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x73, 0x18, 0xd6, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x6f, 0x6e, 0x65, + 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x36, 0x0a, 0x06, 0x67, 0x6f, + 0x5f, 0x74, 0x61, 0x67, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x18, 0xb9, 0x8e, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x67, 0x6f, 0x54, + 0x61, 0x67, 0x3a, 0x32, 0x0a, 0x03, 0x67, 0x65, 0x74, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, + 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x99, 0x88, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x03, 0x67, 0x65, 0x74, 0x3a, 0x34, 0x0a, 0x04, 0x70, 0x6f, 0x73, 0x74, 0x12, 0x1e, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9a, + 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x6f, 0x73, 0x74, 0x3a, 0x32, 0x0a, 0x03, + 0x70, 0x75, 0x74, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x18, 0x9b, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x75, 0x74, + 0x3a, 0x38, 0x0a, 0x06, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, + 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9c, 0x88, 0x03, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x3a, 0x36, 0x0a, 0x05, 0x70, 0x61, + 0x74, 0x63, 0x68, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x18, 0x9d, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x70, 0x61, 0x74, + 0x63, 0x68, 0x3a, 0x3a, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x1e, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9e, 0x88, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x3a, 0x34, + 0x0a, 0x04, 0x68, 0x65, 0x61, 0x64, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9f, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x68, 0x65, 0x61, 0x64, 0x3a, 0x32, 0x0a, 0x03, 0x61, 0x6e, 0x79, 0x12, 0x1e, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, + 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xa0, 0x88, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x61, 0x6e, 0x79, 0x3a, 0x3b, 0x0a, 0x08, 0x67, 0x65, 0x6e, 0x5f, + 0x70, 0x61, 0x74, 0x68, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xfd, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x67, 0x65, + 0x6e, 0x50, 0x61, 0x74, 0x68, 0x3a, 0x41, 0x0a, 0x0b, 0x61, 0x70, 0x69, 0x5f, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xfe, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x70, + 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x3a, 0x32, 0x0a, 0x03, 0x74, 0x61, 0x67, 0x12, + 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, + 0xff, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x74, 0x61, 0x67, 0x3a, 0x34, 0x0a, 0x04, + 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x80, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, + 0x6d, 0x65, 0x3a, 0x3d, 0x0a, 0x09, 0x61, 0x70, 0x69, 0x5f, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x12, + 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, + 0x81, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x70, 0x69, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x3a, 0x40, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x72, 0x12, + 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, + 0x82, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x69, + 0x7a, 0x65, 0x72, 0x3a, 0x36, 0x0a, 0x05, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x12, 0x1e, 0x2e, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, + 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x83, 0x89, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x3a, 0x3a, 0x0a, 0x07, 0x62, + 0x61, 0x73, 0x65, 0x75, 0x72, 0x6c, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x84, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, + 0x62, 0x61, 0x73, 0x65, 0x75, 0x72, 0x6c, 0x3a, 0x43, 0x0a, 0x0c, 0x68, 0x61, 0x6e, 0x64, 0x6c, + 0x65, 0x72, 0x5f, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x85, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0b, 0x68, 0x61, 0x6e, 0x64, 0x6c, 0x65, 0x72, 0x50, 0x61, 0x74, 0x68, 0x3a, 0x58, 0x0a, 0x17, + 0x68, 0x61, 0x6e, 0x64, 0x6c, 0x65, 0x72, 0x5f, 0x70, 0x61, 0x74, 0x68, 0x5f, 0x63, 0x6f, 0x6d, + 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9b, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x15, 0x68, 0x61, 0x6e, 0x64, 0x6c, 0x65, 0x72, 0x50, 0x61, 0x74, 0x68, 0x43, 0x6f, 0x6d, 0x70, + 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x40, 0x0a, 0x09, 0x68, 0x74, 0x74, 0x70, 0x5f, 0x63, + 0x6f, 0x64, 0x65, 0x12, 0x21, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xe1, 0x89, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x08, + 0x68, 0x74, 0x74, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x3a, 0x42, 0x0a, 0x0b, 0x62, 0x61, 0x73, 0x65, + 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xe2, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0a, 0x62, 0x61, 0x73, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x3a, 0x57, 0x0a, 0x16, + 0x62, 0x61, 0x73, 0x65, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x70, + 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xab, 0x8c, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x14, 0x62, 0x61, 0x73, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x43, 0x6f, 0x6d, 0x70, 0x61, + 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x44, 0x0a, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x5f, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xac, 0x8c, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x50, 0x61, 0x74, 0x68, 0x3a, 0x3b, 0x0a, 0x07, 0x72, + 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x8e, 0x8d, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x07, 0x72, 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x42, 0x06, 0x5a, 0x04, 0x2f, 0x61, 0x70, 0x69, +} + +var file_api_proto_goTypes = []interface{}{ + (*descriptorpb.FieldOptions)(nil), // 0: google.protobuf.FieldOptions + (*descriptorpb.MethodOptions)(nil), // 1: google.protobuf.MethodOptions + (*descriptorpb.EnumValueOptions)(nil), // 2: google.protobuf.EnumValueOptions + (*descriptorpb.ServiceOptions)(nil), // 3: google.protobuf.ServiceOptions + (*descriptorpb.MessageOptions)(nil), // 4: google.protobuf.MessageOptions +} +var file_api_proto_depIdxs = []int32{ + 0, // 0: api.raw_body:extendee -> google.protobuf.FieldOptions + 0, // 1: api.query:extendee -> google.protobuf.FieldOptions + 0, // 2: api.header:extendee -> google.protobuf.FieldOptions + 0, // 3: api.cookie:extendee -> google.protobuf.FieldOptions + 0, // 4: api.body:extendee -> google.protobuf.FieldOptions + 0, // 5: api.path:extendee -> google.protobuf.FieldOptions + 0, // 6: api.vd:extendee -> google.protobuf.FieldOptions + 0, // 7: api.form:extendee -> google.protobuf.FieldOptions + 0, // 8: api.js_conv:extendee -> google.protobuf.FieldOptions + 0, // 9: api.file_name:extendee -> google.protobuf.FieldOptions + 0, // 10: api.none:extendee -> google.protobuf.FieldOptions + 0, // 11: api.form_compatible:extendee -> google.protobuf.FieldOptions + 0, // 12: api.js_conv_compatible:extendee -> google.protobuf.FieldOptions + 0, // 13: api.file_name_compatible:extendee -> google.protobuf.FieldOptions + 0, // 14: api.none_compatible:extendee -> google.protobuf.FieldOptions + 0, // 15: api.go_tag:extendee -> google.protobuf.FieldOptions + 1, // 16: api.get:extendee -> google.protobuf.MethodOptions + 1, // 17: api.post:extendee -> google.protobuf.MethodOptions + 1, // 18: api.put:extendee -> google.protobuf.MethodOptions + 1, // 19: api.delete:extendee -> google.protobuf.MethodOptions + 1, // 20: api.patch:extendee -> google.protobuf.MethodOptions + 1, // 21: api.options:extendee -> google.protobuf.MethodOptions + 1, // 22: api.head:extendee -> google.protobuf.MethodOptions + 1, // 23: api.any:extendee -> google.protobuf.MethodOptions + 1, // 24: api.gen_path:extendee -> google.protobuf.MethodOptions + 1, // 25: api.api_version:extendee -> google.protobuf.MethodOptions + 1, // 26: api.tag:extendee -> google.protobuf.MethodOptions + 1, // 27: api.name:extendee -> google.protobuf.MethodOptions + 1, // 28: api.api_level:extendee -> google.protobuf.MethodOptions + 1, // 29: api.serializer:extendee -> google.protobuf.MethodOptions + 1, // 30: api.param:extendee -> google.protobuf.MethodOptions + 1, // 31: api.baseurl:extendee -> google.protobuf.MethodOptions + 1, // 32: api.handler_path:extendee -> google.protobuf.MethodOptions + 1, // 33: api.handler_path_compatible:extendee -> google.protobuf.MethodOptions + 2, // 34: api.http_code:extendee -> google.protobuf.EnumValueOptions + 3, // 35: api.base_domain:extendee -> google.protobuf.ServiceOptions + 3, // 36: api.base_domain_compatible:extendee -> google.protobuf.ServiceOptions + 3, // 37: api.service_path:extendee -> google.protobuf.ServiceOptions + 4, // 38: api.reserve:extendee -> google.protobuf.MessageOptions + 39, // [39:39] is the sub-list for method output_type + 39, // [39:39] is the sub-list for method input_type + 39, // [39:39] is the sub-list for extension type_name + 0, // [0:39] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_api_proto_init() } +func file_api_proto_init() { + if File_api_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_api_proto_rawDesc, + NumEnums: 0, + NumMessages: 0, + NumExtensions: 39, + NumServices: 0, + }, + GoTypes: file_api_proto_goTypes, + DependencyIndexes: file_api_proto_depIdxs, + ExtensionInfos: file_api_proto_extTypes, + }.Build() + File_api_proto = out.File + file_api_proto_rawDesc = nil + file_api_proto_goTypes = nil + file_api_proto_depIdxs = nil +} diff --git a/protobuf/api/api.proto b/protobuf/api/api.proto new file mode 100644 index 0000000..71a3cd3 --- /dev/null +++ b/protobuf/api/api.proto @@ -0,0 +1,76 @@ +syntax = "proto2"; + +package api; + +import "google/protobuf/descriptor.proto"; + +option go_package = "/api"; + +extend google.protobuf.FieldOptions { + optional string raw_body = 50101; + optional string query = 50102; + optional string header = 50103; + optional string cookie = 50104; + optional string body = 50105; + optional string path = 50106; + optional string vd = 50107; + optional string form = 50108; + optional string js_conv = 50109; + optional string file_name = 50110; + optional string none = 50111; + + // 50131~50160 used to extend field option by hz + optional string form_compatible = 50131; + optional string js_conv_compatible = 50132; + optional string file_name_compatible = 50133; + optional string none_compatible = 50134; + // 50135 is reserved to vt_compatible + // optional FieldRules vt_compatible = 50135; + + optional string go_tag = 51001; +} + +extend google.protobuf.MethodOptions { + optional string get = 50201; + optional string post = 50202; + optional string put = 50203; + optional string delete = 50204; + optional string patch = 50205; + optional string options = 50206; + optional string head = 50207; + optional string any = 50208; + optional string gen_path = 50301; // The path specified by the user when the client code is generated, with a higher priority than api_version + optional string api_version = 50302; // Specify the value of the :version variable in path when the client code is generated + optional string tag = 50303; // rpc tag, can be multiple, separated by commas + optional string name = 50304; // Name of rpc + optional string api_level = 50305; // Interface Level + optional string serializer = 50306; // Serialization method + optional string param = 50307; // Whether client requests take public parameters + optional string baseurl = 50308; // Baseurl used in ttnet routing + optional string handler_path = 50309; // handler_path specifies the path to generate the method + + // 50331~50360 used to extend method option by hz + optional string handler_path_compatible = 50331; // handler_path specifies the path to generate the method +} + +extend google.protobuf.EnumValueOptions { + optional int32 http_code = 50401; + + // 50431~50460 used to extend enum option by hz +} + +extend google.protobuf.ServiceOptions { + optional string base_domain = 50402; + + // 50731~50760 used to extend service option by hz + optional string base_domain_compatible = 50731; + optional string service_path = 50732; +} + +extend google.protobuf.MessageOptions { + // optional FieldRules msg_vt = 50111; + + optional string reserve = 50830; + // 550831 is reserved to msg_vt_compatible + // optional FieldRules msg_vt_compatible = 50831; +} \ No newline at end of file diff --git a/protobuf/ast.go b/protobuf/ast.go new file mode 100644 index 0000000..d607a83 --- /dev/null +++ b/protobuf/ast.go @@ -0,0 +1,760 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package protobuf + +import ( + "fmt" + "path/filepath" + "sort" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/generator" + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/protobuf/api" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/hertz/cmd/hz/util/logs" + "github.com/jhump/protoreflect/desc" + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/runtime/protoimpl" + "google.golang.org/protobuf/types/descriptorpb" +) + +var BaseProto = descriptorpb.FileDescriptorProto{} + +// getGoPackage get option go_package +// If pkgMap is specified, the specified value is used as the go_package; +// If go package is not specified, then the value of package is used as go_package. +func getGoPackage(f *descriptorpb.FileDescriptorProto, pkgMap map[string]string) string { + if f.Options == nil { + f.Options = new(descriptorpb.FileOptions) + } + if f.Options.GoPackage == nil { + f.Options.GoPackage = new(string) + } + goPkg := f.Options.GetGoPackage() + + // if go_package has ";", for example go_package="/a/b/c;d", we will use "/a/b/c" as go_package + if strings.Contains(goPkg, ";") { + pkg := strings.Split(goPkg, ";") + if len(pkg) == 2 { + logs.Warnf("The go_package of the file(%s) is \"%s\", hz will use \"%s\" as the go_package.", f.GetName(), goPkg, pkg[0]) + goPkg = pkg[0] + } + + } + + if goPkg == "" { + goPkg = f.GetPackage() + } + if opt, ok := pkgMap[f.GetName()]; ok { + return opt + } + return goPkg +} + +func switchBaseType(typ descriptorpb.FieldDescriptorProto_Type) *model.Type { + switch typ { + case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE, descriptorpb.FieldDescriptorProto_TYPE_GROUP: + return nil + case descriptorpb.FieldDescriptorProto_TYPE_INT64: + return model.TypeInt64 + case descriptorpb.FieldDescriptorProto_TYPE_INT32: + return model.TypeInt32 + case descriptorpb.FieldDescriptorProto_TYPE_UINT64: + return model.TypeUint64 + case descriptorpb.FieldDescriptorProto_TYPE_UINT32: + return model.TypeUint32 + case descriptorpb.FieldDescriptorProto_TYPE_FIXED64: + return model.TypeUint64 + case descriptorpb.FieldDescriptorProto_TYPE_FIXED32: + return model.TypeUint32 + case descriptorpb.FieldDescriptorProto_TYPE_BOOL: + return model.TypeBool + case descriptorpb.FieldDescriptorProto_TYPE_STRING: + return model.TypeString + case descriptorpb.FieldDescriptorProto_TYPE_BYTES: + return model.TypeBinary + case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: + return model.TypeInt32 + case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: + return model.TypeInt64 + case descriptorpb.FieldDescriptorProto_TYPE_SINT32: + return model.TypeInt32 + case descriptorpb.FieldDescriptorProto_TYPE_SINT64: + return model.TypeInt64 + case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: + return model.TypeFloat64 + case descriptorpb.FieldDescriptorProto_TYPE_FLOAT: + return model.TypeFloat32 + } + return nil +} + +func astToService(ast *descriptorpb.FileDescriptorProto, resolver *Resolver, cmdType string, gen *protogen.Plugin) ([]*generator.Service, error) { + resolver.ExportReferred(true, false) + ss := ast.GetService() + out := make([]*generator.Service, 0, len(ss)) + var merges model.Models + + for _, s := range ss { + service := &generator.Service{ + Name: s.GetName(), + } + + service.BaseDomain = "" + domainAnno := getCompatibleAnnotation(s.GetOptions(), api.E_BaseDomain, api.E_BaseDomainCompatible) + if cmdType == meta.CmdClient { + val, ok := domainAnno.(string) + if ok && len(val) != 0 { + service.BaseDomain = val + } + } + + ms := s.GetMethod() + methods := make([]*generator.HttpMethod, 0, len(ms)) + clientMethods := make([]*generator.ClientMethod, 0, len(ms)) + servicePathAnno := checkFirstOption(api.E_ServicePath, s.GetOptions()) + servicePath := "" + if val, ok := servicePathAnno.(string); ok { + servicePath = val + } + for _, m := range ms { + rs := getAllOptions(HttpMethodOptions, m.GetOptions()) + if len(rs) == 0 { + continue + } + httpOpts := httpOptions{} + for k, v := range rs { + httpOpts = append(httpOpts, httpOption{ + method: k, + path: v.(string), + }) + } + // turn the map into a slice and sort it to make sure getting the results in the same order every time + sort.Sort(httpOpts) + + var handlerOutDir string + genPath := getCompatibleAnnotation(m.GetOptions(), api.E_HandlerPath, api.E_HandlerPathCompatible) + handlerOutDir, ok := genPath.(string) + if !ok || len(handlerOutDir) == 0 { + handlerOutDir = "" + } + if len(handlerOutDir) == 0 { + handlerOutDir = servicePath + } + + // protoGoInfo can get generated "Go Info" for proto file. + // the type name may be different between "***.proto" and "***.pb.go" + protoGoInfo, exist := gen.FilesByPath[ast.GetName()] + if !exist { + return nil, fmt.Errorf("file(%s) can not exist", ast.GetName()) + } + methodGoInfo, err := getMethod(protoGoInfo, m) + if err != nil { + return nil, err + } + inputGoType := methodGoInfo.Input + outputGoType := methodGoInfo.Output + + reqName := m.GetInputType() + sb, err := resolver.ResolveIdentifier(reqName) + if err != nil { + return nil, err + } + reqName = util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") + "." + inputGoType.GoIdent.GoName + reqRawName := inputGoType.GoIdent.GoName + reqPackage := util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") + respName := m.GetOutputType() + st, err := resolver.ResolveIdentifier(respName) + if err != nil { + return nil, err + } + respName = util.BaseName(st.Scope.GetOptions().GetGoPackage(), "") + "." + outputGoType.GoIdent.GoName + respRawName := outputGoType.GoIdent.GoName + respPackage := util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") + + var serializer string + sl, sv := checkFirstOptions(SerializerOptions, m.GetOptions()) + if sl != "" { + serializer = sv.(string) + } + + method := &generator.HttpMethod{ + Name: util.CamelString(m.GetName()), + HTTPMethod: httpOpts[0].method, + Path: httpOpts[0].path, + Serializer: serializer, + OutputDir: handlerOutDir, + GenHandler: true, + } + + goOptMapAlias := make(map[string]string, 1) + refs := resolver.ExportReferred(false, true) + method.Models = make(map[string]*model.Model, len(refs)) + for _, ref := range refs { + if val, exist := method.Models[ref.Model.PackageName]; exist { + if val.Package == ref.Model.Package { + method.Models[ref.Model.PackageName] = ref.Model + goOptMapAlias[ref.Model.Package] = ref.Model.PackageName + } else { + file := filepath.Base(ref.Model.FilePath) + fileName := strings.Split(file, ".") + newPkg := fileName[len(fileName)-2] + "_" + val.PackageName + method.Models[newPkg] = ref.Model + goOptMapAlias[ref.Model.Package] = newPkg + } + continue + } + method.Models[ref.Model.PackageName] = ref.Model + goOptMapAlias[ref.Model.Package] = ref.Model.PackageName + } + merges = service.Models + merges.MergeMap(method.Models) + if goOptMapAlias[sb.Scope.GetOptions().GetGoPackage()] != "" { + reqName = goOptMapAlias[sb.Scope.GetOptions().GetGoPackage()] + "." + inputGoType.GoIdent.GoName + } + if goOptMapAlias[sb.Scope.GetOptions().GetGoPackage()] != "" { + respName = goOptMapAlias[st.Scope.GetOptions().GetGoPackage()] + "." + outputGoType.GoIdent.GoName + } + method.RequestTypeName = reqName + method.RequestTypeRawName = reqRawName + method.RequestTypePackage = reqPackage + method.ReturnTypeName = respName + method.ReturnTypeRawName = respRawName + method.ReturnTypePackage = respPackage + + methods = append(methods, method) + for idx, anno := range httpOpts { + if idx == 0 { + continue + } + tmp := *method + tmp.HTTPMethod = anno.method + tmp.Path = anno.path + tmp.GenHandler = false + methods = append(methods, &tmp) + } + + if cmdType == meta.CmdClient { + clientMethod := &generator.ClientMethod{} + clientMethod.HttpMethod = method + err := parseAnnotationToClient(clientMethod, gen, ast, m) + if err != nil { + return nil, err + } + clientMethods = append(clientMethods, clientMethod) + } + } + + service.ClientMethods = clientMethods + service.Methods = methods + service.Models = merges + out = append(out, service) + } + return out, nil +} + +func getCompatibleAnnotation(options proto.Message, anno, compatibleAnno *protoimpl.ExtensionInfo) interface{} { + if proto.HasExtension(options, anno) { + return checkFirstOption(anno, options) + } else if proto.HasExtension(options, compatibleAnno) { + return checkFirstOption(compatibleAnno, options) + } + + return nil +} + +func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen.Plugin, ast *descriptorpb.FileDescriptorProto, m *descriptorpb.MethodDescriptorProto) error { + file, exist := gen.FilesByPath[ast.GetName()] + if !exist { + return fmt.Errorf("file(%s) can not exist", ast.GetName()) + } + method, err := getMethod(file, m) + if err != nil { + return err + } + // pb input type must be message + inputType := method.Input + var ( + hasBodyAnnotation bool + hasFormAnnotation bool + ) + for _, f := range inputType.Fields { + hasAnnotation := false + isStringFieldType := false + if f.Desc.Kind() == protoreflect.StringKind { + isStringFieldType = true + } + if proto.HasExtension(f.Desc.Options(), api.E_Query) { + hasAnnotation = true + queryAnnos := proto.GetExtension(f.Desc.Options(), api.E_Query) + val := checkSnakeName(queryAnnos.(string)) + clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) + } + + if proto.HasExtension(f.Desc.Options(), api.E_Path) { + hasAnnotation = true + pathAnnos := proto.GetExtension(f.Desc.Options(), api.E_Path) + val := checkSnakeName(pathAnnos.(string)) + if isStringFieldType { + clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) + } else { + clientMethod.PathParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName) + } + } + + if proto.HasExtension(f.Desc.Options(), api.E_Header) { + hasAnnotation = true + headerAnnos := proto.GetExtension(f.Desc.Options(), api.E_Header) + val := checkSnakeName(headerAnnos.(string)) + if isStringFieldType { + clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) + } else { + clientMethod.HeaderParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName) + } + } + + if formAnnos := getCompatibleAnnotation(f.Desc.Options(), api.E_Form, api.E_FormCompatible); formAnnos != nil { + hasAnnotation = true + hasFormAnnotation = true + val := checkSnakeName(formAnnos.(string)) + if isStringFieldType { + clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) + } else { + clientMethod.FormValueCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName) + } + } + + if proto.HasExtension(f.Desc.Options(), api.E_Body) { + hasAnnotation = true + hasBodyAnnotation = true + } + + if fileAnnos := getCompatibleAnnotation(f.Desc.Options(), api.E_FileName, api.E_FileNameCompatible); fileAnnos != nil { + hasAnnotation = true + hasFormAnnotation = true + val := checkSnakeName(fileAnnos.(string)) + clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) + } + if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") { + clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(string(f.Desc.Name())), f.GoName) + } + } + clientMethod.BodyParamsCode = meta.SetBodyParam + if hasBodyAnnotation && hasFormAnnotation { + clientMethod.FormValueCode = "" + clientMethod.FormFileCode = "" + } + if !hasBodyAnnotation && hasFormAnnotation { + clientMethod.BodyParamsCode = "" + } + + return nil +} + +func getMethod(file *protogen.File, m *descriptorpb.MethodDescriptorProto) (*protogen.Method, error) { + for _, f := range file.Services { + for _, method := range f.Methods { + if string(method.Desc.Name()) == m.GetName() { + return method, nil + } + } + } + + return nil, fmt.Errorf("can not find method: %s", m.GetName()) +} + +//---------------------------------Model-------------------------------- + +func astToModel(ast *descriptorpb.FileDescriptorProto, rs *Resolver) (*model.Model, error) { + main := rs.mainPkg.Model + if main == nil { + main = new(model.Model) + } + + mainFileDes := rs.files.PbReflect[ast.GetName()] + isProto3 := mainFileDes.IsProto3() + // Enums + ems := ast.GetEnumType() + enums := make([]model.Enum, 0, len(ems)) + for _, e := range ems { + em := model.Enum{ + Scope: main, + Name: e.GetName(), + GoType: "int32", + } + es := e.GetValue() + vs := make([]model.Constant, 0, len(es)) + for _, ee := range es { + vs = append(vs, model.Constant{ + Scope: main, + Name: ee.GetName(), + Type: model.TypeInt32, + Value: model.IntExpression{Src: int(ee.GetNumber())}, + }) + } + em.Values = vs + enums = append(enums, em) + } + main.Enums = enums + + // Structs + sts := ast.GetMessageType() + structs := make([]model.Struct, 0, len(sts)*2) + oneofs := make([]model.Oneof, 0, 1) + for _, st := range sts { + stMessage := mainFileDes.FindMessage(ast.GetPackage() + "." + st.GetName()) + stLeadingComments := getMessageLeadingComments(stMessage) + s := model.Struct{ + Scope: main, + Name: st.GetName(), + Category: model.CategoryStruct, + LeadingComments: stLeadingComments, + } + + ns := st.GetNestedType() + nestedMessageInfoMap := getNestedMessageInfoMap(stMessage) + for _, nt := range ns { + if IsMapEntry(nt) { + continue + } + + nestedMessageInfo := nestedMessageInfoMap[nt.GetName()] + nestedMessageLeadingComment := getMessageLeadingComments(nestedMessageInfo) + s := model.Struct{ + Scope: main, + Name: st.GetName() + "_" + nt.GetName(), + Category: model.CategoryStruct, + LeadingComments: nestedMessageLeadingComment, + } + fs := nt.GetField() + ns := nt.GetNestedType() + vs := make([]model.Field, 0, len(fs)) + + oneofMap := make(map[string]model.Field) + oneofType, err := resolveOneof(nestedMessageInfo, oneofMap, rs, isProto3, s, ns) + if err != nil { + return nil, err + } + oneofs = append(oneofs, oneofType...) + + choiceSet := make(map[string]bool) + + for _, f := range fs { + if field, exist := oneofMap[f.GetName()]; exist { + if _, ex := choiceSet[field.Name]; !ex { + choiceSet[field.Name] = true + vs = append(vs, field) + } + continue + } + dv := f.GetDefaultValue() + fieldLeadingComments, fieldTrailingComments := getFiledComments(f, nestedMessageInfo) + t, err := rs.ResolveType(f, ns) + if err != nil { + return nil, err + } + field := model.Field{ + Scope: &s, + Name: util.CamelString(f.GetName()), + Type: t, + LeadingComments: fieldLeadingComments, + TrailingComments: fieldTrailingComments, + IsPointer: isPointer(f, isProto3), + } + if dv != "" { + field.IsSetDefault = true + field.DefaultValue, err = parseDefaultValue(f.GetType(), f.GetDefaultValue()) + if err != nil { + return nil, err + } + } + err = injectTagsToModel(f, &field, true) + if err != nil { + return nil, err + } + vs = append(vs, field) + } + checkDuplicatedFileName(vs) + s.Fields = vs + structs = append(structs, s) + } + + fs := st.GetField() + vs := make([]model.Field, 0, len(fs)) + + oneofMap := make(map[string]model.Field) + oneofType, err := resolveOneof(stMessage, oneofMap, rs, isProto3, s, ns) + if err != nil { + return nil, err + } + oneofs = append(oneofs, oneofType...) + + choiceSet := make(map[string]bool) + + for _, f := range fs { + if field, exist := oneofMap[f.GetName()]; exist { + if _, ex := choiceSet[field.Name]; !ex { + choiceSet[field.Name] = true + vs = append(vs, field) + } + continue + } + dv := f.GetDefaultValue() + fieldLeadingComments, fieldTrailingComments := getFiledComments(f, stMessage) + t, err := rs.ResolveType(f, ns) + if err != nil { + return nil, err + } + field := model.Field{ + Scope: &s, + Name: util.CamelString(f.GetName()), + Type: t, + LeadingComments: fieldLeadingComments, + TrailingComments: fieldTrailingComments, + IsPointer: isPointer(f, isProto3), + } + if dv != "" { + field.IsSetDefault = true + field.DefaultValue, err = parseDefaultValue(f.GetType(), f.GetDefaultValue()) + if err != nil { + return nil, err + } + } + err = injectTagsToModel(f, &field, true) + if err != nil { + return nil, err + } + vs = append(vs, field) + } + checkDuplicatedFileName(vs) + s.Fields = vs + structs = append(structs, s) + + } + main.Oneofs = oneofs + main.Structs = structs + + // In case of only the service refers another model, therefore scanning service is necessary + ss := ast.GetService() + for _, s := range ss { + ms := s.GetMethod() + for _, m := range ms { + _, err := rs.ResolveIdentifier(m.GetInputType()) + if err != nil { + return nil, err + } + _, err = rs.ResolveIdentifier(m.GetOutputType()) + if err != nil { + return nil, err + } + } + } + + return main, nil +} + +// getMessageLeadingComments can get struct LeadingComment +func getMessageLeadingComments(stMessage *desc.MessageDescriptor) string { + if stMessage == nil { + return "" + } + stComments := stMessage.GetSourceInfo().GetLeadingComments() + stComments = formatComments(stComments) + + return stComments +} + +// getFiledComments can get field LeadingComments and field TailingComments for field +func getFiledComments(f *descriptorpb.FieldDescriptorProto, stMessage *desc.MessageDescriptor) (string, string) { + if stMessage == nil { + return "", "" + } + + fieldNum := f.GetNumber() + field := stMessage.FindFieldByNumber(fieldNum) + fieldInfo := field.GetSourceInfo() + + fieldLeadingComments := fieldInfo.GetLeadingComments() + fieldTailingComments := fieldInfo.GetTrailingComments() + + fieldLeadingComments = formatComments(fieldLeadingComments) + fieldTailingComments = formatComments(fieldTailingComments) + + return fieldLeadingComments, fieldTailingComments +} + +// formatComments can format the comments for beauty +func formatComments(comments string) string { + if len(comments) == 0 { + return "" + } + + comments = util.TrimLastChar(comments) + comments = util.AddSlashForComments(comments) + + return comments +} + +// getNestedMessageInfoMap can get all nested struct +func getNestedMessageInfoMap(stMessage *desc.MessageDescriptor) map[string]*desc.MessageDescriptor { + nestedMessage := stMessage.GetNestedMessageTypes() + nestedMessageInfoMap := make(map[string]*desc.MessageDescriptor, len(nestedMessage)) + + for _, nestedMsg := range nestedMessage { + nestedMsgName := nestedMsg.GetName() + nestedMessageInfoMap[nestedMsgName] = nestedMsg + } + + return nestedMessageInfoMap +} + +func parseDefaultValue(typ descriptorpb.FieldDescriptorProto_Type, val string) (model.Literal, error) { + switch typ { + case descriptorpb.FieldDescriptorProto_TYPE_BYTES, descriptorpb.FieldDescriptorProto_TYPE_STRING: + return model.StringExpression{Src: val}, nil + case descriptorpb.FieldDescriptorProto_TYPE_BOOL: + return model.BoolExpression{Src: val == "true"}, nil + case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE, + descriptorpb.FieldDescriptorProto_TYPE_FLOAT, + descriptorpb.FieldDescriptorProto_TYPE_INT64, + descriptorpb.FieldDescriptorProto_TYPE_UINT64, + descriptorpb.FieldDescriptorProto_TYPE_INT32, + descriptorpb.FieldDescriptorProto_TYPE_FIXED64, + descriptorpb.FieldDescriptorProto_TYPE_FIXED32, + descriptorpb.FieldDescriptorProto_TYPE_UINT32, + descriptorpb.FieldDescriptorProto_TYPE_ENUM, + descriptorpb.FieldDescriptorProto_TYPE_SFIXED32, + descriptorpb.FieldDescriptorProto_TYPE_SFIXED64, + descriptorpb.FieldDescriptorProto_TYPE_SINT32, + descriptorpb.FieldDescriptorProto_TYPE_SINT64: + return model.NumberExpression{Src: val}, nil + default: + return nil, fmt.Errorf("unsupported type %s", typ.String()) + } +} + +func isPointer(f *descriptorpb.FieldDescriptorProto, isProto3 bool) bool { + if f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE || f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_BYTES { + return false + } + + if !isProto3 { + if f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REPEATED { + return false + } + return true + } + + switch f.GetLabel() { + case descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL: + if !f.GetProto3Optional() { + return false + } + return true + default: + return false + } +} + +func resolveOneof(stMessage *desc.MessageDescriptor, oneofMap map[string]model.Field, rs *Resolver, isProto3 bool, s model.Struct, ns []*descriptorpb.DescriptorProto) ([]model.Oneof, error) { + oneofs := make([]model.Oneof, 0, 1) + if len(stMessage.GetOneOfs()) != 0 { + for _, oneof := range stMessage.GetOneOfs() { + if isProto3 { + if oneof.IsSynthetic() { + continue + } + } + oneofName := oneof.GetName() + messageName := s.Name + typeName := "is" + messageName + "_" + oneofName + field := model.Field{ + Scope: &s, + Name: util.CamelString(oneofName), + Type: model.NewOneofType(typeName), + IsPointer: false, + } + + oneofComment := oneof.GetSourceInfo().GetLeadingComments() + oneofComment = formatComments(oneofComment) + var oneofLeadingComments string + if oneofComment == "" { + oneofLeadingComments = fmt.Sprintf(" Types that are assignable to %s:\n", oneofName) + } else { + oneofLeadingComments = fmt.Sprintf("%s\n//\n// Types that are assignable to %s:\n", oneofComment, oneofName) + } + for idx, ch := range oneof.GetChoices() { + if idx == len(oneof.GetChoices())-1 { + oneofLeadingComments = oneofLeadingComments + fmt.Sprintf("// *%s_%s", messageName, ch.GetName()) + } else { + oneofLeadingComments = oneofLeadingComments + fmt.Sprintf("// *%s_%s\n", messageName, ch.GetName()) + } + } + field.LeadingComments = oneofLeadingComments + + choices := make([]model.Choice, 0, len(oneof.GetChoices())) + for _, ch := range oneof.GetChoices() { + t, err := rs.ResolveType(ch.AsFieldDescriptorProto(), ns) + if err != nil { + return nil, err + } + choice := model.Choice{ + MessageName: messageName, + ChoiceName: ch.GetName(), + Type: t, + } + choices = append(choices, choice) + oneofMap[ch.GetName()] = field + } + + oneofType := model.Oneof{ + MessageName: messageName, + OneofName: oneofName, + InterfaceName: typeName, + Choices: choices, + } + + oneofs = append(oneofs, oneofType) + } + } + return oneofs, nil +} + +func getNewFieldName(fieldName string, fieldNameSet map[string]bool) string { + if _, ex := fieldNameSet[fieldName]; ex { + fieldName = fieldName + "_" + return getNewFieldName(fieldName, fieldNameSet) + } + return fieldName +} + +func checkDuplicatedFileName(vs []model.Field) { + fieldNameSet := make(map[string]bool) + for i := 0; i < len(vs); i++ { + if _, ex := fieldNameSet[vs[i].Name]; ex { + newName := getNewFieldName(vs[i].Name, fieldNameSet) + fieldNameSet[newName] = true + vs[i].Name = newName + } else { + fieldNameSet[vs[i].Name] = true + } + } +} diff --git a/protobuf/plugin.go b/protobuf/plugin.go new file mode 100644 index 0000000..2603340 --- /dev/null +++ b/protobuf/plugin.go @@ -0,0 +1,639 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Copyright (c) 2018 The Go Authors. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors. + */ + +package protobuf + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/config" + "github.com/cloudwego/hertz/cmd/hz/generator" + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/hertz/cmd/hz/util/logs" + gengo "google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo" + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/runtime/protoimpl" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/pluginpb" +) + +type Plugin struct { + *protogen.Plugin + Package string + Recursive bool + OutDir string + ModelDir string + UseDir string + IdlClientDir string + RmTags RemoveTags + PkgMap map[string]string + logger *logs.StdLogger +} + +type RemoveTags []string + +func (rm *RemoveTags) Exist(tag string) bool { + for _, rmTag := range *rm { + if rmTag == tag { + return true + } + } + return false +} + +func (plugin *Plugin) Run() int { + plugin.setLogger() + args := &config.Argument{} + defer func() { + if args == nil { + return + } + if args.Verbose { + verboseLog := plugin.recvVerboseLogger() + if len(verboseLog) != 0 { + fmt.Fprintf(os.Stderr, verboseLog) + } + } else { + warning := plugin.recvWarningLogger() + if len(warning) != 0 { + fmt.Fprintf(os.Stderr, warning) + } + } + }() + // read protoc request + in, err := ioutil.ReadAll(os.Stdin) + if err != nil { + logs.Errorf("read request failed: %s\n", err.Error()) + return meta.PluginError + } + + req := &pluginpb.CodeGeneratorRequest{} + err = proto.Unmarshal(in, req) + if err != nil { + logs.Errorf("unmarshal request failed: %s\n", err.Error()) + return meta.PluginError + } + + args, err = plugin.parseArgs(*req.Parameter) + if err != nil { + logs.Errorf("parse args failed: %s\n", err.Error()) + return meta.PluginError + } + CheckTagOption(args) + // generate + err = plugin.Handle(req, args) + if err != nil { + logs.Errorf("generate failed: %s\n", err.Error()) + return meta.PluginError + } + return 0 +} + +func (plugin *Plugin) setLogger() { + plugin.logger = logs.NewStdLogger(logs.LevelInfo) + plugin.logger.Defer = true + plugin.logger.ErrOnly = true + logs.SetLogger(plugin.logger) +} + +func (plugin *Plugin) recvWarningLogger() string { + warns := plugin.logger.Warn() + plugin.logger.Flush() + logs.SetLogger(logs.NewStdLogger(logs.LevelInfo)) + return warns +} + +func (plugin *Plugin) recvVerboseLogger() string { + info := plugin.logger.Out() + warns := plugin.logger.Warn() + verboseLog := string(info) + warns + plugin.logger.Flush() + logs.SetLogger(logs.NewStdLogger(logs.LevelInfo)) + return verboseLog +} + +func (plugin *Plugin) parseArgs(param string) (*config.Argument, error) { + args := new(config.Argument) + params := strings.Split(param, ",") + err := args.Unpack(params) + if err != nil { + return nil, err + } + plugin.Package, err = args.GetGoPackage() + if err != nil { + return nil, err + } + plugin.Recursive = !args.NoRecurse + plugin.ModelDir, err = args.GetModelDir() + if err != nil { + return nil, err + } + plugin.OutDir = args.OutDir + plugin.PkgMap = args.OptPkgMap + plugin.UseDir = args.Use + return args, nil +} + +func (plugin *Plugin) Response(resp *pluginpb.CodeGeneratorResponse) error { + out, err := proto.Marshal(resp) + if err != nil { + return fmt.Errorf("marshal response failed: %s", err.Error()) + } + _, err = os.Stdout.Write(out) + if err != nil { + return fmt.Errorf("write response failed: %s", err.Error()) + } + return nil +} + +func (plugin *Plugin) Handle(req *pluginpb.CodeGeneratorRequest, args *config.Argument) error { + plugin.fixGoPackage(req, plugin.PkgMap) + + // new plugin + opts := protogen.Options{} + gen, err := opts.New(req) + plugin.Plugin = gen + plugin.RmTags = args.RmTags + if err != nil { + return fmt.Errorf("new protoc plugin failed: %s", err.Error()) + } + // plugin start working + err = plugin.GenerateFiles(gen) + if err != nil { + // Error within the plugin will be responded by the plugin. + // But if the plugin does not response correctly, the error is returned to the upper level. + err := fmt.Errorf("generate model file failed: %s", err.Error()) + gen.Error(err) + resp := gen.Response() + err2 := plugin.Response(resp) + if err2 != nil { + return err + } + return nil + } + + if args.CmdType == meta.CmdModel { + resp := gen.Response() + // plugin stop working + err = plugin.Response(resp) + if err != nil { + return fmt.Errorf("write response failed: %s", err.Error()) + } + + return nil + } + + files := gen.Request.ProtoFile + maps := make(map[string]*descriptorpb.FileDescriptorProto, len(files)) + for _, file := range files { + maps[file.GetName()] = file + } + main := maps[gen.Request.FileToGenerate[len(gen.Request.FileToGenerate)-1]] + deps := make(map[string]*descriptorpb.FileDescriptorProto, len(main.GetDependency())) + for _, dep := range main.GetDependency() { + if f, ok := maps[dep]; !ok { + err := fmt.Errorf("dependency file not found: %s", dep) + gen.Error(err) + resp := gen.Response() + err2 := plugin.Response(resp) + if err2 != nil { + return err + } + return nil + } else { + deps[dep] = f + } + } + + pkgFiles, err := plugin.genHttpPackage(main, deps, args) + if err != nil { + err := fmt.Errorf("generate package files failed: %s", err.Error()) + gen.Error(err) + resp := gen.Response() + err2 := plugin.Response(resp) + if err2 != nil { + return err + } + return nil + } + + // construct plugin response + resp := gen.Response() + // all files that need to be generated are returned to protoc + for _, pkgFile := range pkgFiles { + filePath := pkgFile.Path + content := pkgFile.Content + renderFile := &pluginpb.CodeGeneratorResponse_File{ + Name: &filePath, + Content: &content, + } + resp.File = append(resp.File, renderFile) + } + + // plugin stop working + err = plugin.Response(resp) + if err != nil { + return fmt.Errorf("write response failed: %s", err.Error()) + } + + return nil +} + +// fixGoPackage will update go_package to store all the model files in ${model_dir} +func (plugin *Plugin) fixGoPackage(req *pluginpb.CodeGeneratorRequest, pkgMap map[string]string) { + gopkg := plugin.Package + for _, f := range req.ProtoFile { + if strings.HasPrefix(f.GetPackage(), "google.protobuf") { + continue + } + opt := getGoPackage(f, pkgMap) + if !strings.Contains(opt, gopkg) { + if strings.HasPrefix(opt, "/") { + opt = gopkg + opt + } else { + opt = gopkg + "/" + opt + } + } + impt, _ := plugin.fixModelPathAndPackage(opt) + *f.Options.GoPackage = impt + } +} + +// fixModelPathAndPackage will modify the go_package to adapt the go_package of the hz, +// for example adding the go module and model dir. +func (plugin *Plugin) fixModelPathAndPackage(pkg string) (impt, path string) { + if strings.HasPrefix(pkg, plugin.Package) { + impt = util.ImportToPathAndConcat(pkg[len(plugin.Package):], "") + } + if plugin.ModelDir != "" && plugin.ModelDir != "." { + modelImpt := util.PathToImport(string(filepath.Separator)+plugin.ModelDir, "") + // trim model dir for go package + if strings.HasPrefix(impt, modelImpt) { + impt = impt[len(modelImpt):] + } + impt = util.PathToImport(plugin.ModelDir, "") + impt + } + path = util.ImportToPath(impt, "") + impt = plugin.Package + "/" + impt + if util.IsWindows() { + impt = util.PathToImport(impt, "") + } + return +} + +func (plugin *Plugin) GenerateFiles(pluginPb *protogen.Plugin) error { + idl := pluginPb.Request.FileToGenerate[len(pluginPb.Request.FileToGenerate)-1] + pluginPb.SupportedFeatures = gengo.SupportedFeatures + for _, f := range pluginPb.Files { + if f.Proto.GetName() == idl { + err := plugin.GenerateFile(pluginPb, f) + if err != nil { + return err + } + impt := string(f.GoImportPath) + if strings.HasPrefix(impt, plugin.Package) { + impt = impt[len(plugin.Package):] + } + plugin.IdlClientDir = impt + } else if plugin.Recursive { + if strings.HasPrefix(f.Proto.GetPackage(), "google.protobuf") { + continue + } + err := plugin.GenerateFile(pluginPb, f) + if err != nil { + return err + } + } + } + return nil +} + +func (plugin *Plugin) GenerateFile(gen *protogen.Plugin, f *protogen.File) error { + impt := string(f.GoImportPath) + if strings.HasPrefix(impt, plugin.Package) { + impt = impt[len(plugin.Package):] + } + f.GeneratedFilenamePrefix = filepath.Join(util.ImportToPath(impt, ""), util.BaseName(f.Proto.GetName(), ".proto")) + f.Generate = true + // if use third-party model, no model code is generated within the project + if len(plugin.UseDir) != 0 { + return nil + } + file, err := generateFile(gen, f, plugin.RmTags) + if err != nil || file == nil { + return fmt.Errorf("generate file %s failed: %s", f.Proto.GetName(), err.Error()) + } + return nil +} + +// generateFile generates the contents of a .pb.go file. +func generateFile(gen *protogen.Plugin, file *protogen.File, rmTags RemoveTags) (*protogen.GeneratedFile, error) { + filename := file.GeneratedFilenamePrefix + ".pb.go" + g := gen.NewGeneratedFile(filename, file.GoImportPath) + f := newFileInfo(file) + + genStandaloneComments(g, f, int32(FileDescriptorProto_Syntax_field_number)) + genGeneratedHeader(gen, g, f) + genStandaloneComments(g, f, int32(FileDescriptorProto_Package_field_number)) + + packageDoc := genPackageKnownComment(f) + g.P(packageDoc, "package ", f.GoPackageName) + g.P() + + // Emit a static check that enforces a minimum version of the proto package. + if gengo.GenerateVersionMarkers { + g.P("const (") + g.P("// Verify that this generated code is sufficiently up-to-date.") + g.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimpl.GenVersion, " - ", protoimplPackage.Ident("MinVersion"), ")") + g.P("// Verify that runtime/protoimpl is sufficiently up-to-date.") + g.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimplPackage.Ident("MaxVersion"), " - ", protoimpl.GenVersion, ")") + g.P(")") + g.P() + } + + for i, imps := 0, f.Desc.Imports(); i < imps.Len(); i++ { + genImport(gen, g, f, imps.Get(i)) + } + for _, enum := range f.allEnums { + genEnum(g, f, enum) + } + var err error + for _, message := range f.allMessages { + err = genMessage(g, f, message, rmTags) + if err != nil { + return nil, err + } + } + genExtensions(g, f) + + genReflectFileDescriptor(gen, g, f) + + return g, nil +} + +func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, rmTags RemoveTags) error { + if m.Desc.IsMapEntry() { + return nil + } + + // Message type declaration. + g.Annotate(m.GoIdent.GoName, m.Location) + leadingComments := appendDeprecationSuffix(m.Comments.Leading, + m.Desc.Options().(*descriptorpb.MessageOptions).GetDeprecated()) + g.P(leadingComments, + "type ", m.GoIdent, " struct {") + err := genMessageFields(g, f, m, rmTags) + if err != nil { + return err + } + g.P("}") + g.P() + + genMessageKnownFunctions(g, f, m) + genMessageDefaultDecls(g, f, m) + genMessageMethods(g, f, m) + genMessageOneofWrapperTypes(g, f, m) + return nil +} + +func genMessageFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, rmTags RemoveTags) error { + sf := f.allMessageFieldsByPtr[m] + genMessageInternalFields(g, f, m, sf) + var err error + for _, field := range m.Fields { + err = genMessageField(g, f, m, field, sf, rmTags) + if err != nil { + return err + } + } + return nil +} + +func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field, sf *structFields, rmTags RemoveTags) error { + if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() { + // It would be a bit simpler to iterate over the oneofs below, + // but generating the field here keeps the contents of the Go + // struct in the same order as the contents of the source + // .proto file. + if oneof.Fields[0] != field { + return nil // only generate for first appearance + } + + tags := structTags{ + {"protobuf_oneof", string(oneof.Desc.Name())}, + } + if m.isTracked { + tags = append(tags, gotrackTags...) + } + + g.Annotate(m.GoIdent.GoName+"."+oneof.GoName, oneof.Location) + leadingComments := oneof.Comments.Leading + if leadingComments != "" { + leadingComments += "\n" + } + ss := []string{fmt.Sprintf(" Types that are assignable to %s:\n", oneof.GoName)} + for _, field := range oneof.Fields { + ss = append(ss, "\t*"+field.GoIdent.GoName+"\n") + } + leadingComments += protogen.Comments(strings.Join(ss, "")) + g.P(leadingComments, + oneof.GoName, " ", oneofInterfaceName(oneof), tags) + sf.append(oneof.GoName) + return nil + } + goType, pointer := fieldGoType(g, f, field) + if pointer { + goType = "*" + goType + } + tags := structTags{ + {"protobuf", fieldProtobufTagValue(field)}, + //{"json", fieldJSONTagValue(field)}, + } + if field.Desc.IsMap() { + key := field.Message.Fields[0] + val := field.Message.Fields[1] + tags = append(tags, structTags{ + {"protobuf_key", fieldProtobufTagValue(key)}, + {"protobuf_val", fieldProtobufTagValue(val)}, + }...) + } + + err := injectTagsToStructTags(field.Desc, &tags, true, rmTags) + if err != nil { + return err + } + + if m.isTracked { + tags = append(tags, gotrackTags...) + } + + name := field.GoName + if field.Desc.IsWeak() { + name = WeakFieldPrefix_goname + name + } + g.Annotate(m.GoIdent.GoName+"."+name, field.Location) + leadingComments := appendDeprecationSuffix(field.Comments.Leading, + field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated()) + g.P(leadingComments, + name, " ", goType, tags, + trailingComment(field.Comments.Trailing)) + sf.append(field.GoName) + return nil +} + +func (plugin *Plugin) getIdlInfo(ast *descriptorpb.FileDescriptorProto, deps map[string]*descriptorpb.FileDescriptorProto, args *config.Argument) (*generator.HttpPackage, error) { + if ast == nil { + return nil, fmt.Errorf("ast is nil") + } + + pkg := getGoPackage(ast, map[string]string{}) + main := &model.Model{ + FilePath: ast.GetName(), + Package: pkg, + PackageName: util.BaseName(pkg, ""), + } + fileInfo := FileInfos{ + Official: deps, + PbReflect: nil, + } + rs, err := NewResolver(ast, fileInfo, main, map[string]string{}) + if err != nil { + return nil, fmt.Errorf("new protobuf resolver failed, err:%v", err) + } + err = rs.LoadAll(ast) + if err != nil { + return nil, err + } + + services, err := astToService(ast, rs, args.CmdType, plugin.Plugin) + if err != nil { + return nil, err + } + var models model.Models + for _, s := range services { + models.MergeArray(s.Models) + } + + return &generator.HttpPackage{ + Services: services, + IdlName: ast.GetName(), + Package: util.BaseName(pkg, ""), + Models: models, + }, nil +} + +func (plugin *Plugin) genHttpPackage(ast *descriptorpb.FileDescriptorProto, deps map[string]*descriptorpb.FileDescriptorProto, args *config.Argument) ([]generator.File, error) { + options := CheckTagOption(args) + idl, err := plugin.getIdlInfo(ast, deps, args) + if err != nil { + return nil, err + } + + customPackageTemplate := args.CustomizePackage + pkg, err := args.GetGoPackage() + if err != nil { + return nil, err + } + handlerDir, err := args.GetHandlerDir() + if err != nil { + return nil, err + } + routerDir, err := args.GetRouterDir() + if err != nil { + return nil, err + } + modelDir, err := args.GetModelDir() + if err != nil { + return nil, err + } + clientDir, err := args.GetClientDir() + if err != nil { + return nil, err + } + sg := generator.HttpPackageGenerator{ + ConfigPath: customPackageTemplate, + HandlerDir: handlerDir, + RouterDir: routerDir, + ModelDir: modelDir, + UseDir: args.Use, + ClientDir: clientDir, + TemplateGenerator: generator.TemplateGenerator{ + OutputDir: args.OutDir, + Excludes: args.Excludes, + }, + ProjPackage: pkg, + Options: options, + HandlerByMethod: args.HandlerByMethod, + CmdType: args.CmdType, + IdlClientDir: plugin.IdlClientDir, + ForceClientDir: args.ForceClientDir, + BaseDomain: args.BaseDomain, + SnakeStyleMiddleware: args.SnakeStyleMiddleware, + } + + if args.ModelBackend != "" { + sg.Backend = meta.Backend(args.ModelBackend) + } + generator.SetDefaultTemplateConfig() + + err = sg.Generate(idl) + if err != nil { + return nil, fmt.Errorf("generate http package error: %v", err) + } + files, err := sg.GetFormatAndExcludedFiles() + if err != nil { + return nil, fmt.Errorf("persist http package error: %v", err) + } + return files, nil +} diff --git a/protobuf/plugin_stubs.go b/protobuf/plugin_stubs.go new file mode 100644 index 0000000..7c4b40b --- /dev/null +++ b/protobuf/plugin_stubs.go @@ -0,0 +1,232 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Copyright (c) 2018 The Go Authors. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors. + */ + +package protobuf + +import ( + "fmt" + "strconv" + "strings" + "unicode" + "unicode/utf8" + _ "unsafe" + + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/reflect/protoreflect" +) + +// Field numbers for google.protobuf.FileDescriptorProto. +const ( + FileDescriptorProto_Name_field_number protoreflect.FieldNumber = 1 + FileDescriptorProto_Package_field_number protoreflect.FieldNumber = 2 + FileDescriptorProto_Dependency_field_number protoreflect.FieldNumber = 3 + FileDescriptorProto_PublicDependency_field_number protoreflect.FieldNumber = 10 + FileDescriptorProto_WeakDependency_field_number protoreflect.FieldNumber = 11 + FileDescriptorProto_MessageType_field_number protoreflect.FieldNumber = 4 + FileDescriptorProto_EnumType_field_number protoreflect.FieldNumber = 5 + FileDescriptorProto_Service_field_number protoreflect.FieldNumber = 6 + FileDescriptorProto_Extension_field_number protoreflect.FieldNumber = 7 + FileDescriptorProto_Options_field_number protoreflect.FieldNumber = 8 + FileDescriptorProto_SourceCodeInfo_field_number protoreflect.FieldNumber = 9 + FileDescriptorProto_Syntax_field_number protoreflect.FieldNumber = 12 +) + +const WeakFieldPrefix_goname = "XXX_weak_" + +type fileInfo struct { + *protogen.File + + allEnums []*enumInfo + allMessages []*messageInfo + allExtensions []*extensionInfo + + allEnumsByPtr map[*enumInfo]int // value is index into allEnums + allMessagesByPtr map[*messageInfo]int // value is index into allMessages + allMessageFieldsByPtr map[*messageInfo]*structFields + + // needRawDesc specifies whether the generator should emit logic to provide + // the legacy raw descriptor in GZIP'd form. + // This is updated by enum and message generation logic as necessary, + // and checked at the end of file generation. + needRawDesc bool +} + +type enumInfo struct { + *protogen.Enum + + genJSONMethod bool + genRawDescMethod bool +} + +type messageInfo struct { + *protogen.Message + + genRawDescMethod bool + genExtRangeMethod bool + + isTracked bool + hasWeak bool +} + +type extensionInfo struct { + *protogen.Extension +} + +type structFields struct { + count int + unexported map[int]string +} + +func (sf *structFields) append(name string) { + if r, _ := utf8.DecodeRuneInString(name); !unicode.IsUpper(r) { + if sf.unexported == nil { + sf.unexported = make(map[int]string) + } + sf.unexported[sf.count] = name + } + sf.count++ +} + +type structTags [][2]string + +func (tags structTags) String() string { + if len(tags) == 0 { + return "" + } + var ss []string + for _, tag := range tags { + // NOTE: When quoting the value, we need to make sure the backtick + // character does not appear. Convert all cases to the escaped hex form. + key := tag[0] + val := strings.Replace(strconv.Quote(tag[1]), "`", `\x60`, -1) + ss = append(ss, fmt.Sprintf("%s:%s", key, val)) + } + return "`" + strings.Join(ss, " ") + "`" +} + +type goImportPath interface { + String() string + Ident(string) protogen.GoIdent +} + +type trailingComment protogen.Comments + +func (c trailingComment) String() string { + s := strings.TrimSuffix(protogen.Comments(c).String(), "\n") + if strings.Contains(s, "\n") { + // We don't support multi-lined trailing comments as it is unclear + // how to best render them in the generated code. + return "" + } + return s +} + +//go:linkname gotrackTags google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.gotrackTags +var gotrackTags structTags + +var ( + protoPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/proto") + protoifacePackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/runtime/protoiface") + protoimplPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/runtime/protoimpl") + protojsonPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/encoding/protojson") + protoreflectPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/reflect/protoreflect") + protoregistryPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/reflect/protoregistry") +) + +//go:linkname newFileInfo google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.newFileInfo +func newFileInfo(file *protogen.File) *fileInfo + +//go:linkname genPackageKnownComment google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genPackageKnownComment +func genPackageKnownComment(f *fileInfo) protogen.Comments + +//go:linkname genStandaloneComments google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genStandaloneComments +func genStandaloneComments(g *protogen.GeneratedFile, f *fileInfo, n int32) + +//go:linkname genGeneratedHeader google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genGeneratedHeader +func genGeneratedHeader(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo) + +//go:linkname genImport google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genImport +func genImport(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, imp protoreflect.FileImport) + +//go:linkname genEnum google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genEnum +func genEnum(g *protogen.GeneratedFile, f *fileInfo, e *enumInfo) + +//go:linkname genMessageInternalFields google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genMessageInternalFields +func genMessageInternalFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, sf *structFields) + +//go:linkname genExtensions google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genExtensions +func genExtensions(g *protogen.GeneratedFile, f *fileInfo) + +//go:linkname genReflectFileDescriptor google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genReflectFileDescriptor +func genReflectFileDescriptor(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo) + +//go:linkname appendDeprecationSuffix google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.appendDeprecationSuffix +func appendDeprecationSuffix(prefix protogen.Comments, deprecated bool) protogen.Comments + +//go:linkname genMessageDefaultDecls google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genMessageDefaultDecls +func genMessageDefaultDecls(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) + +//go:linkname genMessageKnownFunctions google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genMessageKnownFunctions +func genMessageKnownFunctions(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) + +//go:linkname genMessageMethods google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genMessageMethods +func genMessageMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) + +//go:linkname genMessageOneofWrapperTypes google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genMessageOneofWrapperTypes +func genMessageOneofWrapperTypes(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) + +//go:linkname oneofInterfaceName google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.oneofInterfaceName +func oneofInterfaceName(oneof *protogen.Oneof) string + +//go:linkname fieldGoType google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.fieldGoType +func fieldGoType(g *protogen.GeneratedFile, f *fileInfo, field *protogen.Field) (goType string, pointer bool) + +//go:linkname fieldProtobufTagValue google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.fieldProtobufTagValue +func fieldProtobufTagValue(field *protogen.Field) string + +//go:linkname fieldJSONTagValue google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.fieldJSONTagValue +func fieldJSONTagValue(field *protogen.Field) string diff --git a/protobuf/plugin_test.go b/protobuf/plugin_test.go new file mode 100644 index 0000000..b34d8b8 --- /dev/null +++ b/protobuf/plugin_test.go @@ -0,0 +1,98 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package protobuf + +import ( + "io/ioutil" + "strings" + "testing" + + "github.com/cloudwego/hertz/cmd/hz/meta" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/pluginpb" +) + +func TestPlugin_Handle(t *testing.T) { + in, err := ioutil.ReadFile("../testdata/request_protoc.out") + if err != nil { + t.Fatal(err) + } + + req := &pluginpb.CodeGeneratorRequest{} + err = proto.Unmarshal(in, req) + if err != nil { + t.Fatalf("unmarshal stdin request error: %v", err) + } + + // prepare args + plu := &Plugin{} + plu.setLogger() + args, _ := plu.parseArgs(*req.Parameter) + + plu.Handle(req, args) + plu.recvWarningLogger() +} + +func TestFixModelPathAndPackage(t *testing.T) { + plu := &Plugin{} + plu.Package = "cloudwego/hertz" + plu.ModelDir = meta.ModelDir + // default model dir + ret1 := [][]string{ + {"a/b/c", "cloudwego/hertz/biz/model/a/b/c"}, + {"biz/model/a/b/c", "cloudwego/hertz/biz/model/a/b/c"}, + {"cloudwego/hertz/a/b/c", "cloudwego/hertz/biz/model/a/b/c"}, + {"cloudwego/hertz/biz/model/a/b/c", "cloudwego/hertz/biz/model/a/b/c"}, + } + for _, r := range ret1 { + tmp := r[0] + if !strings.Contains(tmp, plu.Package) { + if strings.HasPrefix(tmp, "/") { + tmp = plu.Package + tmp + } else { + tmp = plu.Package + "/" + tmp + } + } + result, _ := plu.fixModelPathAndPackage(tmp) + if result != r[1] { + t.Fatalf("want go package: %s, but get: %s", r[1], result) + } + } + + plu.ModelDir = "model_test" + // customized model dir + ret2 := [][]string{ + {"a/b/c", "cloudwego/hertz/model_test/a/b/c"}, + {"model_test/a/b/c", "cloudwego/hertz/model_test/a/b/c"}, + {"cloudwego/hertz/a/b/c", "cloudwego/hertz/model_test/a/b/c"}, + {"cloudwego/hertz/model_test/a/b/c", "cloudwego/hertz/model_test/a/b/c"}, + } + for _, r := range ret2 { + tmp := r[0] + if !strings.Contains(tmp, plu.Package) { + if strings.HasPrefix(tmp, "/") { + tmp = plu.Package + tmp + } else { + tmp = plu.Package + "/" + tmp + } + } + result, _ := plu.fixModelPathAndPackage(tmp) + if result != r[1] { + t.Fatalf("want go package: %s, but get: %s", r[1], result) + } + } +} diff --git a/protobuf/resolver.go b/protobuf/resolver.go new file mode 100644 index 0000000..94fe948 --- /dev/null +++ b/protobuf/resolver.go @@ -0,0 +1,530 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package protobuf + +import ( + "fmt" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/jhump/protoreflect/desc" + "google.golang.org/protobuf/types/descriptorpb" +) + +type Symbol struct { + Space string + Name string + IsValue bool + Type *model.Type + Value interface{} + Scope *descriptorpb.FileDescriptorProto +} + +type NameSpace map[string]*Symbol + +var ( + ConstTrue = Symbol{ + IsValue: true, + Type: model.TypeBool, + Value: true, + Scope: &BaseProto, + } + ConstFalse = Symbol{ + IsValue: true, + Type: model.TypeBool, + Value: false, + Scope: &BaseProto, + } + ConstEmptyString = Symbol{ + IsValue: true, + Type: model.TypeString, + Value: "", + Scope: &BaseProto, + } +) + +type PackageReference struct { + IncludeBase string + IncludePath string + Model *model.Model + Ast *descriptorpb.FileDescriptorProto + Referred bool +} + +func getReferPkgMap(pkgMap map[string]string, incs []*descriptorpb.FileDescriptorProto, mainModel *model.Model) (map[string]*PackageReference, error) { + var err error + out := make(map[string]*PackageReference, len(pkgMap)) + pkgAliasMap := make(map[string]string, len(incs)) + // bugfix: add main package to avoid namespace conflict + mainPkg := mainModel.Package + mainPkgName := mainModel.PackageName + mainPkgName, err = util.GetPackageUniqueName(mainPkgName) + if err != nil { + return nil, err + } + pkgAliasMap[mainPkg] = mainPkgName + for _, inc := range incs { + pkg := getGoPackage(inc, pkgMap) + path := inc.GetName() + base := util.BaseName(path, ".proto") + fileName := inc.GetName() + pkgName := util.BaseName(pkg, "") + if pn, exist := pkgAliasMap[pkg]; exist { + pkgName = pn + } else { + pkgName, err = util.GetPackageUniqueName(pkgName) + pkgAliasMap[pkg] = pkgName + if err != nil { + return nil, fmt.Errorf("get package unique name failed, err: %v", err) + } + } + out[fileName] = &PackageReference{base, path, &model.Model{ + FilePath: path, + Package: pkg, + PackageName: pkgName, + }, inc, false} + } + + return out, nil +} + +type FileInfos struct { + Official map[string]*descriptorpb.FileDescriptorProto + PbReflect map[string]*desc.FileDescriptor +} + +type Resolver struct { + // idl symbols + rootName string + root NameSpace + deps map[string]NameSpace + + // exported models + mainPkg PackageReference + refPkgs map[string]*PackageReference + + files FileInfos +} + +func updateFiles(fileName string, files FileInfos) (FileInfos, error) { + file, _ := files.PbReflect[fileName] + if file == nil { + return FileInfos{}, fmt.Errorf("%s not found", fileName) + } + fileDep := file.GetDependencies() + + maps := make(map[string]*descriptorpb.FileDescriptorProto, len(fileDep)+1) + sourceInfoMap := make(map[string]*desc.FileDescriptor, len(fileDep)+1) + for _, dep := range fileDep { + ast := dep.AsFileDescriptorProto() + maps[dep.GetName()] = ast + sourceInfoMap[dep.GetName()] = dep + } + ast := file.AsFileDescriptorProto() + maps[file.GetName()] = ast + sourceInfoMap[file.GetName()] = file + + newFileInfo := FileInfos{ + Official: maps, + PbReflect: sourceInfoMap, + } + + return newFileInfo, nil +} + +func NewResolver(ast *descriptorpb.FileDescriptorProto, files FileInfos, model *model.Model, pkgMap map[string]string) (*Resolver, error) { + file := ast.GetName() + deps := ast.GetDependency() + var err error + if files.PbReflect != nil { + files, err = updateFiles(file, files) + if err != nil { + return nil, err + } + } + incs := make([]*descriptorpb.FileDescriptorProto, 0, len(deps)) + for _, dep := range deps { + if v, ok := files.Official[dep]; ok { + incs = append(incs, v) + } else { + return nil, fmt.Errorf("%s not found", dep) + } + } + pm, err := getReferPkgMap(pkgMap, incs, model) + if err != nil { + return nil, fmt.Errorf("get package map failed, err: %v", err) + } + return &Resolver{ + root: make(NameSpace), + deps: make(map[string]NameSpace), + refPkgs: pm, + files: files, + mainPkg: PackageReference{ + IncludeBase: util.BaseName(file, ".proto"), + IncludePath: file, + Model: model, + Ast: ast, + Referred: false, + }, + }, nil +} + +func (resolver *Resolver) GetRefModel(includeBase string) (*model.Model, error) { + if includeBase == "" { + return resolver.mainPkg.Model, nil + } + ref, ok := resolver.refPkgs[includeBase] + if !ok { + return nil, fmt.Errorf("%s not found", includeBase) + } + return ref.Model, nil +} + +func (resolver *Resolver) getBaseType(f *descriptorpb.FieldDescriptorProto, nested []*descriptorpb.DescriptorProto) (*model.Type, error) { + bt := switchBaseType(f.GetType()) + if bt != nil { + return checkListType(bt, f.GetLabel()), nil + } + + nt := getNestedType(f, nested) + if nt != nil { + fields := nt.GetField() + if IsMapEntry(nt) { + t := *model.TypeBaseMap + tk, err := resolver.ResolveType(fields[0], nt.GetNestedType()) + if err != nil { + return nil, err + } + tv, err := resolver.ResolveType(fields[1], nt.GetNestedType()) + if err != nil { + return nil, err + } + t.Extra = []*model.Type{tk, tv} + return &t, nil + } + } + return nil, nil +} + +func IsMapEntry(nt *descriptorpb.DescriptorProto) bool { + fields := nt.GetField() + return len(fields) == 2 && fields[0].GetName() == "key" && fields[1].GetName() == "value" +} + +func checkListType(typ *model.Type, label descriptorpb.FieldDescriptorProto_Label) *model.Type { + if label == descriptorpb.FieldDescriptorProto_LABEL_REPEATED { + t := *model.TypeBaseList + t.Extra = []*model.Type{typ} + return &t + } + return typ +} + +func getNestedType(f *descriptorpb.FieldDescriptorProto, nested []*descriptorpb.DescriptorProto) *descriptorpb.DescriptorProto { + tName := f.GetTypeName() + entry := util.SplitPackageName(tName, "") + for _, nt := range nested { + if nt.GetName() == entry { + return nt + } + } + return nil +} + +func (resolver *Resolver) ResolveType(f *descriptorpb.FieldDescriptorProto, nested []*descriptorpb.DescriptorProto) (*model.Type, error) { + bt, err := resolver.getBaseType(f, nested) + if err != nil { + return nil, err + } + if bt != nil { + return bt, nil + } + + tName := f.GetTypeName() + symbol, err := resolver.ResolveIdentifier(tName) + if err != nil { + return nil, err + } + deepType := checkListType(symbol.Type, f.GetLabel()) + return deepType, nil +} + +func (resolver *Resolver) ResolveIdentifier(id string) (ret *Symbol, err error) { + ret = resolver.Get(id) + if ret == nil { + return nil, fmt.Errorf("not found identifier %s", id) + } + + var ref *PackageReference + if _, ok := resolver.deps[ret.Space]; ok { + ref = resolver.refPkgs[ret.Scope.GetName()] + if ref != nil { + ref.Referred = true + ret.Type.Scope = ref.Model + } + } + // bugfix: root & dep file has the same package(namespace), the 'ret' will miss the namespace match for root. + // This results in a lack of dependencies in the generated handlers. + if ref == nil && ret.Scope == resolver.mainPkg.Ast { + resolver.mainPkg.Referred = true + ret.Type.Scope = resolver.mainPkg.Model + } + return +} + +func (resolver *Resolver) getFieldType(f *descriptorpb.FieldDescriptorProto, nested []*descriptorpb.DescriptorProto) (*model.Type, error) { + dt, err := resolver.getBaseType(f, nested) + if err != nil { + return nil, err + } + if dt != nil { + return dt, nil + } + sb := resolver.Get(f.GetTypeName()) + if sb != nil { + return sb.Type, nil + } + return nil, fmt.Errorf("not found type %s", f.GetTypeName()) +} + +func (resolver *Resolver) Get(name string) *Symbol { + if strings.HasPrefix(name, "."+resolver.rootName) { + id := strings.TrimPrefix(name, "."+resolver.rootName+".") + if v, ok := resolver.root[id]; ok { + return v + } + } + + // directly map first + var space string + if idx := strings.LastIndex(name, "."); idx >= 0 && idx < len(name)-1 { + space = strings.TrimLeft(name[:idx], ".") + } + if ns, ok := resolver.deps[space]; ok { + id := strings.TrimPrefix(name, "."+space+".") + if s, ok := ns[id]; ok { + return s + } + } + + // iterate check nested type in dependencies + for s, m := range resolver.deps { + if strings.HasPrefix(name, "."+s) { + id := strings.TrimPrefix(name, "."+s+".") + if s, ok := m[id]; ok { + return s + } + } + } + return nil +} + +func (resolver *Resolver) ExportReferred(all, needMain bool) (ret []*PackageReference) { + for _, v := range resolver.refPkgs { + if all { + ret = append(ret, v) + } else if v.Referred { + ret = append(ret, v) + } + v.Referred = false + } + + if needMain && (all || resolver.mainPkg.Referred) { + ret = append(ret, &resolver.mainPkg) + } + resolver.mainPkg.Referred = false + return +} + +func (resolver *Resolver) LoadAll(ast *descriptorpb.FileDescriptorProto) error { + var err error + resolver.root, err = resolver.LoadOne(ast) + if err != nil { + return fmt.Errorf("load main idl failed: %s", err) + } + resolver.rootName = ast.GetPackage() + + includes := ast.GetDependency() + astMap := make(map[string]NameSpace, len(includes)) + for _, dep := range includes { + file, ok := resolver.files.Official[dep] + if !ok { + return fmt.Errorf("not found included idl %s", dep) + } + depNamespace, err := resolver.LoadOne(file) + if err != nil { + return fmt.Errorf("load idl '%s' failed: %s", dep, err) + } + ns, existed := astMap[file.GetPackage()] + if existed { + depNamespace = mergeNamespace(ns, depNamespace) + } + astMap[file.GetPackage()] = depNamespace + } + resolver.deps = astMap + return nil +} + +func mergeNamespace(first, second NameSpace) NameSpace { + for k, v := range second { + if _, existed := first[k]; !existed { + first[k] = v + } + } + return first +} + +func LoadBaseIdentifier(ast *descriptorpb.FileDescriptorProto) map[string]*Symbol { + ret := make(NameSpace, len(ast.GetEnumType())+len(ast.GetMessageType())+len(ast.GetExtension())+len(ast.GetService())) + + ret["true"] = &ConstTrue + ret["false"] = &ConstFalse + ret[`""`] = &ConstEmptyString + ret["bool"] = &Symbol{ + Type: model.TypeBool, + Scope: ast, + } + ret["uint32"] = &Symbol{ + Type: model.TypeUint32, + Scope: ast, + } + ret["uint64"] = &Symbol{ + Type: model.TypeUint64, + Scope: ast, + } + ret["fixed32"] = &Symbol{ + Type: model.TypeUint32, + Scope: ast, + } + ret["fixed64"] = &Symbol{ + Type: model.TypeUint64, + Scope: ast, + } + ret["int32"] = &Symbol{ + Type: model.TypeInt32, + Scope: ast, + } + ret["int64"] = &Symbol{ + Type: model.TypeInt64, + Scope: ast, + } + ret["sint32"] = &Symbol{ + Type: model.TypeInt32, + Scope: ast, + } + ret["sint64"] = &Symbol{ + Type: model.TypeInt64, + Scope: ast, + } + ret["sfixed32"] = &Symbol{ + Type: model.TypeInt32, + Scope: ast, + } + ret["sfixed64"] = &Symbol{ + Type: model.TypeInt64, + Scope: ast, + } + ret["double"] = &Symbol{ + Type: model.TypeFloat64, + Scope: ast, + } + ret["float"] = &Symbol{ + Type: model.TypeFloat32, + Scope: ast, + } + ret["string"] = &Symbol{ + Type: model.TypeString, + Scope: ast, + } + ret["bytes"] = &Symbol{ + Type: model.TypeBinary, + Scope: ast, + } + return ret +} + +func (resolver *Resolver) LoadOne(ast *descriptorpb.FileDescriptorProto) (NameSpace, error) { + ret := LoadBaseIdentifier(ast) + space := util.BaseName(ast.GetPackage(), "") + prefix := "." + space + + for _, e := range ast.GetEnumType() { + name := strings.TrimLeft(e.GetName(), prefix) + ret[e.GetName()] = &Symbol{ + Name: name, + Space: space, + IsValue: false, + Value: e, + Scope: ast, + Type: model.NewEnumType(name, model.CategoryEnum), + } + for _, ee := range e.GetValue() { + name := strings.TrimLeft(ee.GetName(), prefix) + ret[ee.GetName()] = &Symbol{ + Name: name, + Space: space, + IsValue: true, + Value: ee, + Scope: ast, + Type: model.NewCategoryType(model.TypeInt, model.CategoryEnum), + } + } + } + + for _, mt := range ast.GetMessageType() { + name := strings.TrimLeft(mt.GetName(), prefix) + ret[mt.GetName()] = &Symbol{ + Name: name, + Space: space, + IsValue: false, + Value: mt, + Scope: ast, + Type: model.NewStructType(name, model.CategoryStruct), + } + + for _, nt := range mt.GetNestedType() { + ntname := name + "_" + nt.GetName() + ret[name+"."+nt.GetName()] = &Symbol{ + Name: ntname, + Space: space, + IsValue: false, + Value: nt, + Scope: ast, + Type: model.NewStructType(ntname, model.CategoryStruct), + } + } + } + + for _, s := range ast.GetService() { + name := strings.TrimLeft(s.GetName(), prefix) + ret[s.GetName()] = &Symbol{ + Name: name, + Space: space, + IsValue: false, + Value: s, + Scope: ast, + Type: model.NewFuncType(name, model.CategoryService), + } + } + + return ret, nil +} + +func (resolver *Resolver) GetFiles() FileInfos { + return resolver.files +} diff --git a/protobuf/tag_test.go b/protobuf/tag_test.go new file mode 100644 index 0000000..2e7a9ca --- /dev/null +++ b/protobuf/tag_test.go @@ -0,0 +1,153 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package protobuf + +import ( + "io/ioutil" + "strings" + "testing" + + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/pluginpb" +) + +func TestTagGenerate(t *testing.T) { + type TagStruct struct { + Annotation string + GeneratedTag string + ActualTag string + } + + tagList := []TagStruct{ + { + Annotation: "query", + GeneratedTag: "protobuf:\"bytes,1,opt,name=QueryTag\" json:\"QueryTag,omitempty\" query:\"query\"", + }, + { + Annotation: "raw_body", + GeneratedTag: "protobuf:\"bytes,2,opt,name=RawBodyTag\" json:\"RawBodyTag,omitempty\" raw_body:\"raw_body\"", + }, + { + Annotation: "path", + GeneratedTag: "protobuf:\"bytes,3,opt,name=PathTag\" json:\"PathTag,omitempty\" path:\"path\"", + }, + { + Annotation: "form", + GeneratedTag: "protobuf:\"bytes,4,opt,name=FormTag\" json:\"FormTag,omitempty\" form:\"form\"", + }, + { + Annotation: "cookie", + GeneratedTag: "protobuf:\"bytes,5,opt,name=CookieTag\" json:\"CookieTag,omitempty\" cookie:\"cookie\"", + }, + { + Annotation: "header", + GeneratedTag: "protobuf:\"bytes,6,opt,name=HeaderTag\" json:\"HeaderTag,omitempty\" header:\"header\"", + }, + { + Annotation: "body", + GeneratedTag: "bytes,7,opt,name=BodyTag\" form:\"body\" json:\"body,omitempty\"", + }, + { + Annotation: "go.tag", + GeneratedTag: "bytes,8,opt,name=GoTag\" json:\"json\" form:\"form\" goTag:\"tag\" header:\"header\" query:\"query\"", + }, + { + Annotation: "vd", + GeneratedTag: "bytes,9,opt,name=VdTag\" json:\"VdTag,omitempty\" form:\"VdTag\" query:\"VdTag\" vd:\"$!='?'\"", + }, + { + Annotation: "non", + GeneratedTag: "bytes,10,opt,name=DefaultTag\" json:\"DefaultTag,omitempty\" form:\"DefaultTag\" query:\"DefaultTag\"", + }, + { + Annotation: "query required", + GeneratedTag: "bytes,11,req,name=ReqQuery\" json:\"ReqQuery,required\" query:\"query,required\"", + }, + { + Annotation: "query optional", + GeneratedTag: "bytes,12,opt,name=OptQuery\" json:\"OptQuery,omitempty\" query:\"query\"", + }, + { + Annotation: "body required", + GeneratedTag: "protobuf:\"bytes,13,req,name=ReqBody\" form:\"body,required\" json:\"body,required\"", + }, + { + Annotation: "body optional", + GeneratedTag: "protobuf:\"bytes,14,opt,name=OptBody\" form:\"body\" json:\"body,omitempty\"", + }, + { + Annotation: "go.tag required", + GeneratedTag: "protobuf:\"bytes,15,req,name=ReqGoTag\" query:\"ReqGoTag,required\" form:\"ReqGoTag,required\" json:\"json\"", + }, + { + Annotation: "go.tag optional", + GeneratedTag: "bytes,16,opt,name=OptGoTag\" query:\"OptGoTag\" form:\"OptGoTag\" json:\"json\"", + }, + { + Annotation: "go tag cover query", + GeneratedTag: "bytes,17,req,name=QueryGoTag\" json:\"QueryGoTag,required\" query:\"queryTag\"", + }, + } + + in, err := ioutil.ReadFile("./test_data/protobuf_tag_test.out") + if err != nil { + t.Fatal(err) + } + + req := &pluginpb.CodeGeneratorRequest{} + err = proto.Unmarshal(in, req) + if err != nil { + t.Fatalf("unmarshal stdin request error: %v", err) + } + + opts := protogen.Options{} + gen, err := opts.New(req) + + for _, f := range gen.Files { + if f.Proto.GetName() == "test_tag.proto" { + fileInfo := newFileInfo(f) + for _, message := range fileInfo.allMessages { + for idx, field := range message.Fields { + tags := structTags{ + {"protobuf", fieldProtobufTagValue(field)}, + } + err = injectTagsToStructTags(field.Desc, &tags, true, nil) + if err != nil { + t.Fatal(err) + } + var actualTag string + for i, tag := range tags { + if i == 0 { + actualTag = tag[0] + ":" + "\"" + tag[1] + "\"" + } else { + actualTag = actualTag + " " + tag[0] + ":" + "\"" + tag[1] + "\"" + } + } + tagList[idx].ActualTag = actualTag + + } + } + } + } + + for i := range tagList { + if !strings.Contains(tagList[i].ActualTag, tagList[i].GeneratedTag) { + t.Fatalf("expected tag: '%s', but autual tag: '%s'", tagList[i].GeneratedTag, tagList[i].ActualTag) + } + } +} diff --git a/protobuf/tags.go b/protobuf/tags.go new file mode 100644 index 0000000..26c6cc7 --- /dev/null +++ b/protobuf/tags.go @@ -0,0 +1,474 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package protobuf + +import ( + "fmt" + "sort" + "strconv" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/config" + "github.com/cloudwego/hertz/cmd/hz/generator" + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/protobuf/api" + "github.com/cloudwego/hertz/cmd/hz/util" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/runtime/protoimpl" + "google.golang.org/protobuf/types/descriptorpb" +) + +var ( + jsonSnakeName = false + unsetOmitempty = false + protobufCamelJSONTagStyle = false +) + +func CheckTagOption(args *config.Argument) (ret []generator.Option) { + if args == nil { + return + } + if args.SnakeName { + jsonSnakeName = true + } + if args.UnsetOmitempty { + unsetOmitempty = true + } + if args.JSONEnumStr { + ret = append(ret, generator.OptionMarshalEnumToText) + } + if args.ProtobufCamelJSONTag { + protobufCamelJSONTagStyle = true + } + return ret +} + +func checkSnakeName(name string) string { + if jsonSnakeName { + name = util.ToSnakeCase(name) + } + return name +} + +var ( + HttpMethodOptions = map[*protoimpl.ExtensionInfo]string{ + api.E_Get: "GET", + api.E_Post: "POST", + api.E_Put: "PUT", + api.E_Patch: "PATCH", + api.E_Delete: "DELETE", + api.E_Options: "OPTIONS", + api.E_Head: "HEAD", + api.E_Any: "Any", + } + + BindingTags = map[*protoimpl.ExtensionInfo]string{ + api.E_Path: "path", + api.E_Query: "query", + api.E_Header: "header", + api.E_Cookie: "cookie", + api.E_Body: "json", + // Do not change the relative order of "api.E_Form" and "api.E_Body", so that "api.E_Form" can overwrite the form tag generated by "api.E_Body" + api.E_Form: "form", + api.E_FormCompatible: "form", + api.E_RawBody: "raw_body", + } + + ValidatorTags = map[*protoimpl.ExtensionInfo]string{api.E_Vd: "vd"} + + SerializerOptions = map[*protoimpl.ExtensionInfo]string{api.E_Serializer: "serializer"} +) + +type httpOption struct { + method string + path string +} + +type httpOptions []httpOption + +func (s httpOptions) Len() int { + return len(s) +} + +func (s httpOptions) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s httpOptions) Less(i, j int) bool { + return s[i].method < s[j].method +} + +func getAllOptions(extensions map[*protoimpl.ExtensionInfo]string, opts ...protoreflect.ProtoMessage) map[string]interface{} { + out := map[string]interface{}{} + for _, opt := range opts { + for e, t := range extensions { + if proto.HasExtension(opt, e) { + v := proto.GetExtension(opt, e) + out[t] = v + } + } + } + return out +} + +func checkFirstOptions(extensions map[*protoimpl.ExtensionInfo]string, opts ...protoreflect.ProtoMessage) (string, interface{}) { + for _, opt := range opts { + for e, t := range extensions { + if proto.HasExtension(opt, e) { + v := proto.GetExtension(opt, e) + return t, v + } + } + } + return "", nil +} + +func checkFirstOption(ext *protoimpl.ExtensionInfo, opts ...protoreflect.ProtoMessage) interface{} { + for _, opt := range opts { + if proto.HasExtension(opt, ext) { + v := proto.GetExtension(opt, ext) + return v + } + } + return nil +} + +func checkOption(ext *protoimpl.ExtensionInfo, opts ...protoreflect.ProtoMessage) (ret []interface{}) { + for _, opt := range opts { + if proto.HasExtension(opt, ext) { + v := proto.GetExtension(opt, ext) + ret = append(ret, v) + } + } + return +} + +func tag(k string, v interface{}) model.Tag { + return model.Tag{ + Key: k, + Value: fmt.Sprintf("%v", v), + } +} + +//-----------------------------------For Compiler--------------------------- + +func defaultBindingTags(f *descriptorpb.FieldDescriptorProto) []model.Tag { + opts := f.GetOptions() + out := make([]model.Tag, 3) + if v := checkFirstOption(api.E_Body, opts); v != nil { + val := getJsonValue(f, v.(string)) + out[0] = tag("json", val) + } else { + out[0] = jsonTag(f) + } + if v := checkFirstOption(api.E_Query, opts); v != nil { + val := checkRequire(f, v.(string)) + out[1] = tag(BindingTags[api.E_Query], val) + } else { + val := checkRequire(f, checkSnakeName(f.GetName())) + out[1] = tag(BindingTags[api.E_Query], val) + } + if v := checkFirstOption(api.E_Form, opts); v != nil { + val := checkRequire(f, v.(string)) + out[2] = tag(BindingTags[api.E_Form], val) + } else { + val := checkRequire(f, checkSnakeName(f.GetName())) + out[2] = tag(BindingTags[api.E_Form], val) + } + return out +} + +func jsonTag(f *descriptorpb.FieldDescriptorProto) (ret model.Tag) { + ret.Key = "json" + ret.Value = checkSnakeName(f.GetJsonName()) + if v := checkFirstOption(api.E_JsConv, f.GetOptions()); v != nil { + ret.Value += ",string" + } else if v := checkFirstOption(api.E_JsConvCompatible, f.GetOptions()); v != nil { + ret.Value += ",string" + } + if !unsetOmitempty && f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL { + ret.Value += ",omitempty" + } else if f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { + ret.Value += ",required" + } + return +} + +func injectTagsToModel(f *descriptorpb.FieldDescriptorProto, gf *model.Field, needDefault bool) error { + as := f.GetOptions() + + tags := gf.Tags + if tags == nil { + tags = make([]model.Tag, 0, 4) + } + + // binding tags + if needDefault { + tags = append(tags, defaultBindingTags(f)...) + } + for k, v := range BindingTags { + if vv := checkFirstOption(k, as); vv != nil { + tags.Remove(v) + if v == "json" { + vv = getJsonValue(f, vv.(string)) + } else { + vv = checkRequire(f, vv.(string)) + } + tags = append(tags, tag(v, vv)) + } + } + + // validator tags + for k, v := range ValidatorTags { + for _, vv := range checkOption(k, as) { + tags = append(tags, tag(v, vv)) + } + } + + // go.tags + for _, v := range checkOption(api.E_GoTag, as) { + gts := util.SplitGoTags(v.(string)) + for _, gt := range gts { + sp := strings.SplitN(gt, ":", 2) + if len(sp) != 2 { + return fmt.Errorf("invalid go tag: %s", v) + } + vv, err := strconv.Unquote(sp[1]) + if err != nil { + return fmt.Errorf("invalid go.tag value: %s, err: %v", sp[1], err.Error()) + } + key := sp[0] + tags.Remove(key) + tags = append(tags, model.Tag{ + Key: key, + Value: vv, + }) + } + } + + sort.Sort(tags) + gf.Tags = tags + return nil +} + +func getJsonValue(f *descriptorpb.FieldDescriptorProto, val string) string { + if v := checkFirstOption(api.E_JsConv, f.GetOptions()); v != nil { + val += ",string" + } else if v := checkFirstOption(api.E_JsConvCompatible, f.GetOptions()); v != nil { + val += ",string" + } + if !unsetOmitempty && f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL { + val += ",omitempty" + } else if f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { + val += ",required" + } + + return val +} + +func checkRequire(f *descriptorpb.FieldDescriptorProto, val string) string { + if f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { + val += ",required" + } + + return val +} + +//-------------------------For plugin--------------------------------- + +func m2s(mt model.Tag) (ret [2]string) { + ret[0] = mt.Key + ret[1] = mt.Value + return ret +} + +func reflectJsonTag(f protoreflect.FieldDescriptor) (ret model.Tag) { + ret.Key = "json" + if protobufCamelJSONTagStyle { + ret.Value = checkSnakeName(f.JSONName()) + } else { + ret.Value = checkSnakeName(string(f.Name())) + } + if v := checkFirstOption(api.E_Body, f.Options()); v != nil { + ret.Value += ",string" + } + if descriptorpb.FieldDescriptorProto_Label(f.Cardinality()) == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { + ret.Value += ",required" + } else if !unsetOmitempty { + ret.Value += ",omitempty" + } + return +} + +func defaultBindingStructTags(f protoreflect.FieldDescriptor) []model.Tag { + opts := f.Options() + out := make([]model.Tag, 3) + bindingTags := []*protoimpl.ExtensionInfo{ + api.E_Path, + api.E_Query, + api.E_Form, + api.E_FormCompatible, + api.E_Header, + api.E_Cookie, + api.E_Body, + api.E_RawBody, + } + // If the user provides an annotation, return json tag directly + for _, tag := range bindingTags { + if vv := checkFirstOption(tag, opts); vv != nil { + out[0] = reflectJsonTag(f) + return out[:1] + } + } + + if v := checkFirstOption(api.E_Body, opts); v != nil { + val := getStructJsonValue(f, v.(string)) + out[0] = tag("json", val) + } else { + t := reflectJsonTag(f) + t.IsDefault = true + out[0] = t + } + if v := checkFirstOption(api.E_Query, opts); v != nil { + val := checkStructRequire(f, v.(string)) + out[1] = tag(BindingTags[api.E_Query], val) + } else { + val := checkStructRequire(f, checkSnakeName(string(f.Name()))) + t := tag(BindingTags[api.E_Query], val) + t.IsDefault = true + out[1] = t + } + if v := checkFirstOption(api.E_Form, opts); v != nil { + val := checkStructRequire(f, v.(string)) + out[2] = tag(BindingTags[api.E_Form], val) + } else { + if v := checkFirstOption(api.E_FormCompatible, opts); v != nil { // compatible form_compatible + val := checkStructRequire(f, v.(string)) + t := tag(BindingTags[api.E_Form], val) + t.IsDefault = true + out[2] = t + } else { + val := checkStructRequire(f, checkSnakeName(string(f.Name()))) + t := tag(BindingTags[api.E_Form], val) + t.IsDefault = true + out[2] = t + } + } + return out +} + +func injectTagsToStructTags(f protoreflect.FieldDescriptor, out *structTags, needDefault bool, rmTags RemoveTags) error { + as := f.Options() + // binding tags + tags := model.Tags(make([]model.Tag, 0, 6)) + + if needDefault { + tags = append(tags, defaultBindingStructTags(f)...) + } + for k, v := range BindingTags { + if vv := checkFirstOption(k, as); vv != nil { + tags.Remove(v) + // body annotation will generate "json" & "form" tag for protobuf + if v == "json" { + formVal := vv + vv = getStructJsonValue(f, vv.(string)) + formVal = checkStructRequire(f, formVal.(string)) + tags = append(tags, tag("form", formVal)) + } else { + vv = checkStructRequire(f, vv.(string)) + } + tags = append(tags, tag(v, vv)) + } + } + + // validator tags + for k, v := range ValidatorTags { + if vv := checkFirstOption(k, as); vv != nil { + tags = append(tags, tag(v, vv)) + } + } + + if v := checkFirstOption(api.E_GoTag, as); v != nil { + gts := util.SplitGoTags(v.(string)) + for _, gt := range gts { + sp := strings.SplitN(gt, ":", 2) + if len(sp) != 2 { + return fmt.Errorf("invalid go tag: %s", v) + } + vv, err := strconv.Unquote(sp[1]) + if err != nil { + return fmt.Errorf("invalid go.tag value: %s, err: %v", sp[1], err.Error()) + } + key := sp[0] + tags.Remove(key) + tags = append(tags, model.Tag{ + Key: key, + Value: vv, + }) + } + } + disableTag := false + if vv := checkFirstOption(api.E_None, as); vv != nil { + if strings.EqualFold(vv.(string), "true") { + disableTag = true + } + } else if vv := checkFirstOption(api.E_NoneCompatible, as); vv != nil { + if strings.EqualFold(vv.(string), "true") { + disableTag = true + } + } + for _, t := range tags { + if t.IsDefault && rmTags.Exist(t.Key) { + tags.Remove(t.Key) + } + } + // protobuf tag as first + sort.Sort(tags[1:]) + for _, t := range tags { + if disableTag { + *out = append(*out, [2]string{t.Key, "-"}) + } else { + *out = append(*out, m2s(t)) + } + } + return nil +} + +func getStructJsonValue(f protoreflect.FieldDescriptor, val string) string { + if v := checkFirstOption(api.E_JsConv, f.Options()); v != nil { + val += ",string" + } else if v := checkFirstOption(api.E_JsConvCompatible, f.Options()); v != nil { + val += ",string" + } + + if descriptorpb.FieldDescriptorProto_Label(f.Cardinality()) == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { + val += ",required" + } else if !unsetOmitempty { + val += ",omitempty" + } + + return val +} + +func checkStructRequire(f protoreflect.FieldDescriptor, val string) string { + if descriptorpb.FieldDescriptorProto_Label(f.Cardinality()) == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { + val += ",required" + } + + return val +} diff --git a/protobuf/test_data/protobuf_tag_test.out b/protobuf/test_data/protobuf_tag_test.out new file mode 100644 index 0000000..0a4dff1 Binary files /dev/null and b/protobuf/test_data/protobuf_tag_test.out differ diff --git a/protobuf/test_data/test_tag.proto b/protobuf/test_data/test_tag.proto new file mode 100644 index 0000000..96cf906 --- /dev/null +++ b/protobuf/test_data/test_tag.proto @@ -0,0 +1,32 @@ +syntax = "proto2"; + +package test; + +option go_package = "cloudwego.hertz.hz"; + +import "api.proto"; + +message MultiTagReq { + // basic feature + optional string QueryTag = 1 [(api.query)="query"]; + optional string RawBodyTag = 2 [(api.raw_body)="raw_body"]; + optional string PathTag = 3 [(api.path)="path"]; + optional string FormTag = 4 [(api.form)="form"]; + optional string CookieTag = 5 [(api.cookie)="cookie"]; + optional string HeaderTag = 6 [(api.header)="header"]; + optional string BodyTag = 7 [(api.body)="body"]; + optional string GoTag = 8 [(api.go_tag)="json:\"json\" query:\"query\" form:\"form\" header:\"header\" goTag:\"tag\""]; + optional string VdTag = 9 [(api.vd)="$!='?'"]; + optional string DefaultTag = 10; + + // optional / required + required string ReqQuery = 11 [(api.query)="query"]; + optional string OptQuery = 12 [(api.query)="query"]; + required string ReqBody = 13 [(api.body)="body"]; + optional string OptBody = 14 [(api.body)="body"]; + required string ReqGoTag = 15 [(api.go_tag)="json:\"json\""]; + optional string OptGoTag = 16 [(api.go_tag)="json:\"json\""]; + + // gotag cover feature + required string QueryGoTag = 17 [(api.query)="query", (api.go_tag)="query:\"queryTag\""]; +} diff --git a/test_hz_unix.sh b/test_hz_unix.sh new file mode 100644 index 0000000..133f90b --- /dev/null +++ b/test_hz_unix.sh @@ -0,0 +1,105 @@ +#! /usr/bin/env bash + +# const value define +moduleName="github.com/cloudwego/hertz/cmd/hz/test" +curDir=`pwd` +thriftIDL=$curDir"/testdata/thrift/psm.thrift" +protobuf2IDL=$curDir"/testdata/protobuf2/psm/psm.proto" +proto2Search=$curDir"/testdata/protobuf2" +protobuf3IDL=$curDir"/testdata/protobuf3/psm/psm.proto" +proto3Search=$curDir"/testdata/protobuf3" +protoSearch="/usr/local/include" + +judge_exit() { + code=$1 + if [ $code != 0 ]; then + exit $code + fi +} + +compile_hz() { + go build -o hz + judge_exit "$?" +} + +install_dependent_tools() { + # install thriftgo + go install github.com/cloudwego/thriftgo@latest + + # install protoc + wget https://github.com/protocolbuffers/protobuf/releases/download/v3.19.4/protoc-3.19.4-linux-x86_64.zip + unzip -d protoc-3.19.4-linux-x86_64 protoc-3.19.4-linux-x86_64.zip + cp protoc-3.19.4-linux-x86_64/bin/protoc /usr/local/bin/protoc + cp -r protoc-3.19.4-linux-x86_64/include/google /usr/local/include/google +} + +test_thrift() { + mkdir -p test + cd test + ../hz new --idl=$thriftIDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router + judge_exit "$?" + go mod tidy && go build . + judge_exit "$?" + ../hz update --idl=$thriftIDL + judge_exit "$?" + ../hz model --idl=$thriftIDL --model_dir=hertz_model + judge_exit "$?" + ../hz client --idl=$thriftIDL --client_dir=hertz_client + judge_exit "$?" + cd .. + rm -rf test +} + +test_protobuf2() { + # test protobuf2 + mkdir -p test + cd test + ../hz new -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router + judge_exit "$?" + go mod tidy && go build . + judge_exit "$?" + ../hz update -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL + judge_exit "$?" + ../hz model -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --model_dir=hertz_model + judge_exit "$?" + ../hz client -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --client_dir=hertz_client + judge_exit "$?" + cd .. + rm -rf test +} + +test_protobuf3() { + # test protobuf2 + mkdir -p test + cd test + ../hz new -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router + judge_exit "$?" + go mod tidy && go build . + judge_exit "$?" + ../hz update -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL + judge_exit "$?" + ../hz model -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --model_dir=hertz_model + judge_exit "$?" + ../hz client -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --client_dir=hertz_client + judge_exit "$?" + cd .. + rm -rf test +} + +main() { + compile_hz + judge_exit "$?" + install_dependent_tools + judge_exit "$?" + echo "test thrift......" + test_thrift + judge_exit "$?" + echo "test protobuf2......" + test_protobuf2 + judge_exit "$?" + echo "test protobuf3......" + test_protobuf3 + judge_exit "$?" + echo "hz execute success" +} +main diff --git a/test_hz_windows.sh b/test_hz_windows.sh new file mode 100644 index 0000000..2f7aad0 --- /dev/null +++ b/test_hz_windows.sh @@ -0,0 +1,101 @@ +#! /usr/bin/env bash + +# const value define +moduleName="github.com/cloudwego/hertz/cmd/hz/test" +curDir=`pwd` +thriftIDL=$curDir"/testdata/thrift/psm.thrift" +protobuf2IDL=$curDir"/testdata/protobuf2/psm/psm.proto" +proto2Search=$curDir"/testdata/protobuf2" +protobuf3IDL=$curDir"/testdata/protobuf3/psm/psm.proto" +proto3Search=$curDir"/testdata/protobuf3" +protoSearch=$curDir"/testdata/include" + +judge_exit() { + code=$1 + if [ $code != 0 ]; then + exit $code + fi +} + +compile_hz() { + go install . + judge_exit "$?" +} + +install_dependent_tools() { + # install thriftgo + go install github.com/cloudwego/thriftgo@latest +} + +test_thrift() { + # test thrift + mkdir -p test + cd test + hz new --idl=$thriftIDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router + judge_exit "$?" + go mod tidy && go build . + judge_exit "$?" + hz update --idl=$thriftIDL + judge_exit "$?" + hz model --idl=$thriftIDL --model_dir=hertz_model + judge_exit "$?" + hz client --idl=$thriftIDL --client_dir=hertz_client + judge_exit "$?" + cd .. + rm -rf test +} + +test_protobuf2() { + # test protobuf2 + mkdir -p test + cd test + hz new -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router + judge_exit "$?" + go mod tidy && go build . + judge_exit "$?" + hz update -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL + judge_exit "$?" + hz model -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --model_dir=hertz_model + judge_exit "$?" + hz client -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --client_dir=hertz_client + judge_exit "$?" + cd .. + rm -rf test +} + +test_protobuf3() { + # test protobuf2 + mkdir -p test + cd test + hz new -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router + judge_exit "$?" + go mod tidy && go build . + judge_exit "$?" + hz update -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL + judge_exit "$?" + hz model -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --model_dir=hertz_model + judge_exit "$?" + hz client -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --client_dir=hertz_client + judge_exit "$?" + cd .. + rm -rf test +} + +main() { + compile_hz + judge_exit "$?" + install_dependent_tools + judge_exit "$?" +# todo: add thrift test when thriftgo fixed windows + echo "test thrift......" + test_thrift + judge_exit "$?" + echo "test protobuf2......" + test_protobuf2 + judge_exit "$?" + echo "test protobuf3......" + test_protobuf3 + judge_exit "$?" + echo "hz execute success" +} +main diff --git a/testdata/protobuf2/api.proto b/testdata/protobuf2/api.proto new file mode 100644 index 0000000..9081737 --- /dev/null +++ b/testdata/protobuf2/api.proto @@ -0,0 +1,65 @@ +syntax = "proto2"; + +package api; + +import "google/protobuf/descriptor.proto"; + +option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/api"; + +extend google.protobuf.FieldOptions { + optional string raw_body = 50101; + optional string query = 50102; + optional string header = 50103; + optional string cookie = 50104; + optional string body = 50105; + optional string path = 50106; + optional string vd = 50107; + optional string form = 50108; + optional string js_conv = 50109; + optional string file_name = 50110; + optional string none = 50111; + + // 50131~50160 used to extend field option by hz + optional string form_compatible = 50131; + optional string js_conv_compatible = 50132; + optional string file_name_compatible = 50133; + optional string none_compatible = 50134; + + optional string go_tag = 51001; +} + +extend google.protobuf.MethodOptions { + optional string get = 50201; + optional string post = 50202; + optional string put = 50203; + optional string delete = 50204; + optional string patch = 50205; + optional string options = 50206; + optional string head = 50207; + optional string any = 50208; + optional string gen_path = 50301; // The path specified by the user when the client code is generated, with a higher priority than api_version + optional string api_version = 50302; // Specify the value of the :version variable in path when the client code is generated + optional string tag = 50303; // rpc tag, can be multiple, separated by commas + optional string name = 50304; // Name of rpc + optional string api_level = 50305; // Interface Level + optional string serializer = 50306; // Serialization method + optional string param = 50307; // Whether client requests take public parameters + optional string baseurl = 50308; // Baseurl used in ttnet routing + optional string handler_path = 50309; // handler_path specifies the path to generate the method + + // 50331~50360 used to extend method option by hz + optional string handler_path_compatible = 50331; // handler_path specifies the path to generate the method +} + +extend google.protobuf.EnumValueOptions { + optional int32 http_code = 50401; + + // 50431~50460 used to extend enum option by hz +} + +extend google.protobuf.ServiceOptions { + optional string base_domain = 50402; + + // 50731~50760 used to extend service option by hz + optional string base_domain_compatible = 50731; +} \ No newline at end of file diff --git a/testdata/protobuf2/other/other.proto b/testdata/protobuf2/other/other.proto new file mode 100644 index 0000000..d5d883d --- /dev/null +++ b/testdata/protobuf2/other/other.proto @@ -0,0 +1,12 @@ +syntax = "proto2"; + +package hertz.other; + +import "other/other_base.proto"; + +option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/other"; + +message OtherType { + optional string IsBaseString = 1; + optional OtherBaseType IsOtherBaseType = 2; +} \ No newline at end of file diff --git a/testdata/protobuf2/other/other_base.proto b/testdata/protobuf2/other/other_base.proto new file mode 100644 index 0000000..32c4b32 --- /dev/null +++ b/testdata/protobuf2/other/other_base.proto @@ -0,0 +1,9 @@ +syntax = "proto2"; + +package hertz.other; + +option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/other"; + +message OtherBaseType { + optional string IsOtherBaseTypeString = 1; +} \ No newline at end of file diff --git a/testdata/protobuf2/psm/base.proto b/testdata/protobuf2/psm/base.proto new file mode 100644 index 0000000..6489492 --- /dev/null +++ b/testdata/protobuf2/psm/base.proto @@ -0,0 +1,14 @@ +syntax = "proto2"; + +package base; + +option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/psm"; + +message Base { + optional string IsBaseString = 1; +} + +enum BaseEnumType { + TWEET = 0; + RETWEET = 1; +} \ No newline at end of file diff --git a/testdata/protobuf2/psm/psm.proto b/testdata/protobuf2/psm/psm.proto new file mode 100644 index 0000000..b885366 --- /dev/null +++ b/testdata/protobuf2/psm/psm.proto @@ -0,0 +1,155 @@ +syntax = "proto2"; + +package psm; + +option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/psm"; + +import "api.proto"; +import "base.proto"; +import "other/other.proto"; + +enum EnumType { + TWEET = 0; + RETWEET = 1; +} +message UnusedMessageType { + optional string IsUnusedMessageType = 1; +} + +message BaseType { + optional base.Base IsBaseType = 1; +} + +message MultiTypeReq { + // basic type (leading comments) + optional bool IsBoolOpt = 1; + required bool IsBoolReq = 2; + optional int32 IsInt32Opt = 3; + required int32 IsInt32Req = 4; + optional int64 IsInt64Opt = 5; + optional uint32 IsUInt32Opt = 6; + optional uint64 IsUInt64Opt = 7; + optional sint32 IsSInt32Opt = 8; + optional sint64 IsSInt64Opt = 9; + optional fixed32 IsFix32Opt = 10; + optional fixed64 IsFix64Opt = 11; + optional sfixed32 IsSFix32Opt = 12; + optional sfixed64 IsSFix64Opt = 13; + optional double IsDoubleOpt = 14; + required double IsDoubleReq = 15; + optional float IsFloatOpt = 16; + optional string IsStringOpt = 17; + required string IsStringReq = 18; + optional bytes IsBytesOpt = 19; + optional bytes IsBytesReq = 20; + + // slice + repeated string IsRepeatedString = 21; + repeated BaseType IsRepeatedBaseType = 22; + + // map + map IsStringMap = 23; + map IsBaseTypeMap = 24; + + // oneof + // multiple comments + oneof TestOneof { + string IsOneofString = 25; + BaseType IsOneofBaseType = 26; + int32 IsOneofInt = 100; + bool IsOneofBool = 101; + double IsOneoDouble = 102; + bytes IsOneofBytes = 103; + } + + // this is oneof2, one field in oneof + oneof TestOneof2 { + string IsOneof2String = 104; + } + + message NestedMessageType { + optional string IsNestedString = 1; + optional BaseType IsNestedBaseType = 2; + repeated BaseType IsNestedRepeatedBaseType = 3; + // nested oneof + oneof NestedMsgOneof { + string IsNestedMsgOneofString = 4; + EnumType IsNestedMsgOneofEnumType = 5; + } + } + // nested message + optional NestedMessageType IsNestedType = 27; + + // other dependency + optional base.Base IsCurrentPackageBase = 28; + optional hertz.other.OtherType IsOtherType = 29; + + // enum + optional EnumType IsEnumTypeOpt = 30; + required EnumType IsEnumTypeReq = 31; + repeated EnumType IsEnumTypeList = 32; + optional base.BaseEnumType IsBaseEnumType = 33; +} + +message MultiTagReq { + optional string QueryTag = 1 [(api.query) = "query", (api.none) = "true"]; + optional string RawBodyTag = 2 [(api.raw_body) = "raw_body"]; + optional string CookieTag = 3 [(api.cookie) = "cookie"]; + optional string BodyTag = 4 [(api.body) = "body"]; + optional string PathTag = 5 [(api.path) = "path"]; + optional string VdTag = 6 [(api.vd) = "$!='?'"]; + optional string FormTag = 7 [(api.form) = "form"]; + optional string DefaultTag = 8 [(api.go_tag) = "FFF:\"fff\" json:\"json\""]; +} + +message CompatibleAnnoReq { + optional string FormCompatibleTag = 1 [(api.form_compatible) = "form"]; + optional string FilenameCompatibleTag = 2 [(api.file_name_compatible) = "file_name"]; + optional string NoneCompatibleTag = 3 [(api.none_compatible) = "true"]; + optional string JsConvCompatibleTag = 4 [(api.js_conv_compatible) = "true"]; +} + +message Resp { + optional string Resp = 1; +} + +message MultiNameStyleMessage { + optional string hertz = 1; + optional string Hertz = 2; + optional string hertz_demo = 3; + optional string hertz_demo_idl = 4; + optional string hertz_Idl = 5; + optional string hertzDemo = 6; + optional string h = 7; + optional string H = 8; + optional string hertz_ = 9; +} + +service Hertz { + rpc Method1(MultiTypeReq) returns(Resp) { + option (api.get) = "/company/department/group/user:id/name"; + } + rpc Method2(MultiTypeReq) returns(Resp) { + option (api.post) = "/company/department/group/user:id/sex"; + } + rpc Method3(MultiTypeReq) returns(Resp) { + option (api.put) = "/company/department/group/user:id/number"; + } + rpc Method4(MultiTypeReq) returns(Resp) { + option (api.delete) = "/company/department/group/user:id/age"; + } + + + rpc Method5(MultiTagReq) returns(Resp) { + option (api.options) = "/school/class/student/name"; + } + rpc Method6(MultiTagReq) returns(Resp) { + option (api.head) = "/school/class/student/number"; + } + rpc Method7(MultiTagReq) returns(Resp) { + option (api.patch) = "/school/class/student/sex"; + } + rpc Method8(MultiTagReq) returns(Resp) { + option (api.any) = "/school/class/student/grade/*subjects"; + } +} \ No newline at end of file diff --git a/testdata/protobuf3/api.proto b/testdata/protobuf3/api.proto new file mode 100644 index 0000000..9081737 --- /dev/null +++ b/testdata/protobuf3/api.proto @@ -0,0 +1,65 @@ +syntax = "proto2"; + +package api; + +import "google/protobuf/descriptor.proto"; + +option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/api"; + +extend google.protobuf.FieldOptions { + optional string raw_body = 50101; + optional string query = 50102; + optional string header = 50103; + optional string cookie = 50104; + optional string body = 50105; + optional string path = 50106; + optional string vd = 50107; + optional string form = 50108; + optional string js_conv = 50109; + optional string file_name = 50110; + optional string none = 50111; + + // 50131~50160 used to extend field option by hz + optional string form_compatible = 50131; + optional string js_conv_compatible = 50132; + optional string file_name_compatible = 50133; + optional string none_compatible = 50134; + + optional string go_tag = 51001; +} + +extend google.protobuf.MethodOptions { + optional string get = 50201; + optional string post = 50202; + optional string put = 50203; + optional string delete = 50204; + optional string patch = 50205; + optional string options = 50206; + optional string head = 50207; + optional string any = 50208; + optional string gen_path = 50301; // The path specified by the user when the client code is generated, with a higher priority than api_version + optional string api_version = 50302; // Specify the value of the :version variable in path when the client code is generated + optional string tag = 50303; // rpc tag, can be multiple, separated by commas + optional string name = 50304; // Name of rpc + optional string api_level = 50305; // Interface Level + optional string serializer = 50306; // Serialization method + optional string param = 50307; // Whether client requests take public parameters + optional string baseurl = 50308; // Baseurl used in ttnet routing + optional string handler_path = 50309; // handler_path specifies the path to generate the method + + // 50331~50360 used to extend method option by hz + optional string handler_path_compatible = 50331; // handler_path specifies the path to generate the method +} + +extend google.protobuf.EnumValueOptions { + optional int32 http_code = 50401; + + // 50431~50460 used to extend enum option by hz +} + +extend google.protobuf.ServiceOptions { + optional string base_domain = 50402; + + // 50731~50760 used to extend service option by hz + optional string base_domain_compatible = 50731; +} \ No newline at end of file diff --git a/testdata/protobuf3/other/other.proto b/testdata/protobuf3/other/other.proto new file mode 100644 index 0000000..d5d883d --- /dev/null +++ b/testdata/protobuf3/other/other.proto @@ -0,0 +1,12 @@ +syntax = "proto2"; + +package hertz.other; + +import "other/other_base.proto"; + +option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/other"; + +message OtherType { + optional string IsBaseString = 1; + optional OtherBaseType IsOtherBaseType = 2; +} \ No newline at end of file diff --git a/testdata/protobuf3/other/other_base.proto b/testdata/protobuf3/other/other_base.proto new file mode 100644 index 0000000..32c4b32 --- /dev/null +++ b/testdata/protobuf3/other/other_base.proto @@ -0,0 +1,9 @@ +syntax = "proto2"; + +package hertz.other; + +option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/other"; + +message OtherBaseType { + optional string IsOtherBaseTypeString = 1; +} \ No newline at end of file diff --git a/testdata/protobuf3/psm/base.proto b/testdata/protobuf3/psm/base.proto new file mode 100644 index 0000000..ea8891a --- /dev/null +++ b/testdata/protobuf3/psm/base.proto @@ -0,0 +1,9 @@ +syntax = "proto2"; + +package base; + +option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/psm"; + +message Base { + optional string IsBaseString = 1; +} \ No newline at end of file diff --git a/testdata/protobuf3/psm/psm.proto b/testdata/protobuf3/psm/psm.proto new file mode 100644 index 0000000..3dfeaf5 --- /dev/null +++ b/testdata/protobuf3/psm/psm.proto @@ -0,0 +1,133 @@ +syntax = "proto3"; + +package psm; + +option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/psm"; + +import "api.proto"; +import "base.proto"; +import "other/other.proto"; + +enum EnumType { + TWEET = 0; + RETWEET = 1; +} +message UnusedMessageType { + optional string IsUnusedMessageType = 1; +} + +message BaseType { + optional base.Base IsBaseType = 1; +} + +message MultiTypeReq { + // basic type (leading comments) + optional bool IsBoolOpt = 1; + optional int32 IsInt32Opt = 3; + int64 IsInt64Default = 5; + optional uint32 IsUInt32Opt = 6; + uint64 IsUInt64Default = 7; + optional sint32 IsSInt32Opt = 8; + sint64 IsSInt64Default = 9; + optional fixed32 IsFix32Opt = 10; + optional fixed64 IsFix64Opt = 11; + optional sfixed32 IsSFix32Opt = 12; + optional sfixed64 IsSFix64Opt = 13; + optional double IsDoubleOpt = 14; + optional float IsFloatOpt = 16; + optional string IsStringOpt = 17; + optional bytes IsBytesOpt = 19; + bytes IsBytesDefault = 20; + + // slice + repeated string IsRepeatedString = 21; + repeated BaseType IsRepeatedBaseType = 22; + + // map + map IsStringMap = 23; + map IsBaseTypeMap = 24; + + // oneof + oneof TestOneof { + string IsOneofString = 25; + BaseType IsOneofBaseTypeString = 26; + } + + oneof TestOneof2 { + string IsOneofString2 = 100; + } + + // nested message + message NestedMessageType { + oneof NestedOneof { + string YYY = 4; + string GGG = 5; + } + optional string IsNestedString = 1; + optional BaseType IsNestedBaseType = 2; + repeated BaseType IsNestedRepeatedBaseType = 3; + } + optional NestedMessageType IsNestedType = 27; + + // other dependency + optional base.Base IsCurrentPackageBase = 28; + optional hertz.other.OtherType IsOtherType = 29; + + // enum + optional EnumType IsEnumTypeOpt = 30; + EnumType IsEnumTypeDefault = 31; +} + +message MultiTagReq { + optional string QueryTag = 1 [(api.query) = "query", (api.none) = "true"]; + optional string RawBodyTag = 2 [(api.raw_body)="raw_body"]; + optional string CookieTag = 3 [(api.cookie)="cookie"]; + optional string BodyTag = 4 [(api.body)="body"]; + optional string PathTag = 5 [(api.path)="path"]; + optional string VdTag = 6 [(api.vd)="$!='?'"]; + optional string DefaultTag = 7; + oneof TestOneof { + string IsOneofString = 25; + BaseType IsOneofBaseTypeString = 26; + } +} + +message CompatibleAnnoReq { + optional string FormCompatibleTag = 1 [(api.form_compatible) = "form"]; + optional string FilenameCompatibleTag = 2 [(api.file_name_compatible) = "file_name"]; + optional string NoneCompatibleTag = 3 [(api.none_compatible) = "true"]; + optional string JsConvCompatibleTag = 4 [(api.js_conv_compatible) = "true"]; +} + +message Resp { + optional string Resp = 1; +} + +service Hertz { + rpc Method1(MultiTypeReq) returns(Resp) { + option (api.get)="/company/department/group/user:id/name"; + } + rpc Method2(MultiTypeReq) returns(Resp) { + option (api.post)="/company/department/group/user:id/sex"; + } + rpc Method3(MultiTypeReq) returns(Resp) { + option (api.put)="/company/department/group/user:id/number"; + } + rpc Method4(MultiTypeReq) returns(Resp) { + option (api.delete)="/company/department/group/user:id/age"; + } + + + rpc Method5(MultiTagReq) returns(Resp) { + option (api.options)="/school/class/student/name"; + } + rpc Method6(MultiTagReq) returns(Resp) { + option (api.head)="/school/class/student/number"; + } + rpc Method7(MultiTagReq) returns(Resp) { + option (api.patch)="/school/class/student/sex"; + } + rpc Method8(MultiTagReq) returns(Resp) { + option (api.any)="/school/class/student/grade/*subjects"; + } +} \ No newline at end of file diff --git a/testdata/thrift/common.thrift b/testdata/thrift/common.thrift new file mode 100644 index 0000000..4efbb82 --- /dev/null +++ b/testdata/thrift/common.thrift @@ -0,0 +1,13 @@ +namespace go toutiao.middleware.hertz + +struct CommonType { + 1: required string IsCommonString; + 2: optional string TTT; + 3: required bool HHH; + 4: required Base GGG; +} + +struct Base { + 1: optional string AAA; + 2: optional i32 BBB; +} \ No newline at end of file diff --git a/testdata/thrift/data/basic_data.thrift b/testdata/thrift/data/basic_data.thrift new file mode 100644 index 0000000..70f6b5a --- /dev/null +++ b/testdata/thrift/data/basic_data.thrift @@ -0,0 +1,5 @@ +namespace go toutiao.middleware.hertz_data + +struct BasicDataType { + 1: optional string IsBasicDataString; +} \ No newline at end of file diff --git a/testdata/thrift/data/data.thrift b/testdata/thrift/data/data.thrift new file mode 100644 index 0000000..a0b7615 --- /dev/null +++ b/testdata/thrift/data/data.thrift @@ -0,0 +1,7 @@ +include "basic_data.thrift" + +namespace go toutiao.middleware.hertz_data + +struct DataType { + 1: optional basic_data.BasicDataType IsDataString; +} \ No newline at end of file diff --git a/testdata/thrift/psm.thrift b/testdata/thrift/psm.thrift new file mode 100644 index 0000000..816e77a --- /dev/null +++ b/testdata/thrift/psm.thrift @@ -0,0 +1,122 @@ +include "common.thrift" +include "data/data.thrift" + +namespace go toutiao.middleware.hertz + +const string STRING_CONST = "hertz"; + +enum EnumType { + TWEET, + RETWEET = 2, +} + +typedef i32 MyInteger + +struct BaseType { + 1: string GoTag = "test" (go.tag="json:\"go\" goTag:\"tag\""); + 2: optional string IsBaseString = "test"; + 3: optional common.CommonType IsDepCommonType = {"IsCommonString":"test", "TTT":"test", "HHH":true, "GGG": {"AAA":"test","BBB":32}}; + 4: optional EnumType IsBaseTypeEnum = 1; +} + +typedef common.CommonType FFF + +typedef BaseType MyBaseType + +struct MultiTypeReq { + // basic type (leading comments) + 1: optional bool IsBoolOpt = true; // trailing comments + 2: required bool IsBoolReq; + 3: optional byte IsByteOpt = 8; + 4: required byte IsByteReq; + //5: optional i8 IsI8Opt; // unsupported i8, suggest byte + //6: required i8 IsI8Req = 5; // default + 7: optional i16 IsI16Opt = 16; + 8: optional i32 IsI32Opt; + 9: optional i64 IsI64Opt; + 10: optional double IsDoubleOpt; + 11: required double IsDoubleReq; + 12: optional string IsStringOpt = "test"; + 13: required string IsStringReq; + + 14: optional list IsList; + 22: required list IsListReq; + 15: optional set IsSet; + 16: optional map IsMap; + 21: optional map IsStructMap; + + // struct type + 17: optional BaseType IsBaseType; // use struct name + 18: optional MyBaseType IsMyBaseType; // use typedef for struct + 19: optional common.CommonType IsCommonType = {"IsCommonString": "fffff"}; + 20: optional data.DataType IsDataType; // multi-dependent struct +} + +typedef data.DataType IsMyDataType + +struct MultiTagReq { + 1: string QueryTag (api.query="query"); + 2: string RawBodyTag (api.raw_body="raw_body"); + 3: string PathTag (api.path="path"); + 4: string FormTag (api.form="form"); + 5: string CookieTag (api.cookie="cookie"); + 6: string HeaderTag (api.header="header"); + 7: string ProtobufTag (api.protobuf="protobuf"); + 8: string BodyTag (api.body="body"); + 9: string GoTag (go.tag="json:\"go\" goTag:\"tag\""); + 10: string VdTag (api.vd="$!='?'"); + 11: string DefaultTag; +} + +struct Resp { + 1: string Resp = "this is Resp"; +} + +struct MultiNameStyleReq { + 1: optional string hertz; + 2: optional string Hertz; + 3: optional string hertz_demo; + 4: optional string hertz_demo_idl; + 5: optional string hertz_Idl; + 6: optional string hertzDemo; + 7: optional string h; + 8: optional string H; + 9: optional string hertz_; +} + +struct MultiDefaultReq { + 1: optional bool IsBoolOpt = true; + 2: required bool IsBoolReq = false; + 3: optional i32 IsI32Opt = 32; + 4: required i32 IsI32Req = 32; + 5: optional string IsStringOpt = "test"; + 6: required string IsStringReq = "test"; + + 14: optional list IsListOpt = ["test", "ttt", "sdsds"]; + 22: required list IsListReq = ["test", "ttt", "sdsds"]; + 15: optional set IsSet = ["test", "ttt", "sdsds"]; + 16: optional map IsMapOpt = {"test": "ttt", "ttt": "lll"}; + 17: required map IsMapReq = {"test": "ttt", "ttt": "lll"}; + 21: optional map IsStructMapOpt = {"test": {"GoTag":"fff", "IsBaseTypeEnum":1, "IsBaseString":"ddd", "IsDepCommonType": {"IsCommonString":"fffffff", "TTT":"ttt", "HHH":true, "GGG": {"AAA":"test","BBB":32}}}}; + 25: required map IsStructMapReq = {"test": {"GoTag":"fff", "IsBaseTypeEnum":1, "IsBaseString":"ddd", "IsDepCommonType": {"IsCommonString":"fffffff", "TTT":"ttt", "HHH":true, "GGG": {"AAA":"test","BBB":32}}}}; + + 23: optional common.CommonType IsDepCommonTypeOpt = {"IsCommonString":"fffffff", "TTT":"ttt", "HHH":true, "GGG": {"AAA":"test","BBB":32}}; + 24: required common.CommonType IsDepCommonTypeReq = {"IsCommonString":"fffffff", "TTT":"ttt", "HHH":true, "GGG": {"AAA":"test","BBB":32}}; +} + +typedef map IsTypedefContainer + +service Hertz { + Resp Method1(1: MultiTypeReq request) (api.get="/company/department/group/user:id/name", api.handler_path="v1"); + Resp Method2(1: MultiTagReq request) (api.post="/company/department/group/user:id/sex", api.handler_path="v1"); + Resp Method3(1: BaseType request) (api.put="/company/department/group/user:id/number", api.handler_path="v1"); + Resp Method4(1: data.DataType request) (api.delete="/company/department/group/user:id/age", api.handler_path="v1"); + + Resp Method5(1: MultiTypeReq request) (api.options="/school/class/student/name", api.handler_path="v2"); + Resp Method6(1: MultiTagReq request) (api.head="/school/class/student/number", api.handler_path="v2"); + Resp Method7(1: MultiTagReq request) (api.patch="/school/class/student/sex", api.handler_path="v2"); + Resp Method8(1: BaseType request) (api.any="/school/class/student/grade/*subjects", api.handler_path="v2"); + + Resp Method9(1: IsTypedefContainer request) (api.get="/typedef/container", api.handler_path="v2"); + Resp Method10(1: map request) (api.get="/container", api.handler_path="v2"); +} \ No newline at end of file diff --git a/thrift/ast.go b/thrift/ast.go new file mode 100644 index 0000000..5005a20 --- /dev/null +++ b/thrift/ast.go @@ -0,0 +1,916 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "fmt" + "sort" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/config" + "github.com/cloudwego/hertz/cmd/hz/generator" + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/hertz/cmd/hz/util/logs" + "github.com/cloudwego/thriftgo/generator/golang" + "github.com/cloudwego/thriftgo/generator/golang/styles" + "github.com/cloudwego/thriftgo/parser" + "github.com/cloudwego/thriftgo/semantic" +) + +/*---------------------------Import-----------------------------*/ + +func getGoPackage(ast *parser.Thrift, pkgMap map[string]string) string { + filePackage := ast.GetFilename() + if opt, ok := pkgMap[filePackage]; ok { + return opt + } else { + goPackage := ast.GetNamespaceOrReferenceName("go") + if goPackage != "" { + return util.SplitPackage(goPackage, "") + } + // If namespace is not declared, the file name (without the extension) is used as the package name + return util.SplitPackage(filePackage, ".thrift") + } +} + +/*---------------------------Service-----------------------------*/ + +func astToService(ast *parser.Thrift, resolver *Resolver, args *config.Argument) ([]*generator.Service, error) { + ss := ast.GetServices() + out := make([]*generator.Service, 0, len(ss)) + var models model.Models + extendServices := getExtendServices(ast) + for _, s := range ss { + // if the service is extended, it is not processed + if extendServices.exist(s.Name) && args.EnableExtends { + logs.Debugf("%s is extended, so skip it\n", s.Name) + continue + } + + resolver.ExportReferred(true, false) + service := &generator.Service{ + Name: s.GetName(), + } + service.BaseDomain = "" + domainAnno := getAnnotation(s.Annotations, ApiBaseDomain) + if len(domainAnno) == 1 { + if args.CmdType == meta.CmdClient { + service.BaseDomain = domainAnno[0] + } + } + service.ServiceGroup = "" + groupAnno := getAnnotation(s.Annotations, ApiServiceGroup) + if len(groupAnno) == 1 { + if args.CmdType != meta.CmdClient { + service.ServiceGroup = groupAnno[0] + } + } + service.ServiceGenDir = "" + serviceGenDirAnno := getAnnotation(s.Annotations, ApiServiceGenDir) + if len(serviceGenDirAnno) == 1 { + if args.CmdType != meta.CmdClient { + service.ServiceGenDir = serviceGenDirAnno[0] + } + } + ms := s.GetFunctions() + if len(s.Extends) != 0 && args.EnableExtends { + // all the services that are extended to the current service + extendsFuncs, err := getAllExtendFunction(s, ast, resolver, args) + if err != nil { + return nil, fmt.Errorf("parser extend function failed, err=%v", err) + } + ms = append(ms, extendsFuncs...) + } + methods := make([]*generator.HttpMethod, 0, len(ms)) + clientMethods := make([]*generator.ClientMethod, 0, len(ms)) + servicePathAnno := getAnnotation(s.Annotations, ApiServicePath) + servicePath := "" + if len(servicePathAnno) > 0 { + servicePath = servicePathAnno[0] + } + for _, m := range ms { + rs := getAnnotations(m.Annotations, HttpMethodAnnotations) + if len(rs) == 0 { + continue + } + httpAnnos := httpAnnotations{} + for k, v := range rs { + httpAnnos = append(httpAnnos, httpAnnotation{ + method: k, + path: v, + }) + } + // turn the map into a slice and sort it to make sure getting the results in the same order every time + sort.Sort(httpAnnos) + handlerOutDir := servicePath + genPaths := getAnnotation(m.Annotations, ApiGenPath) + if len(genPaths) == 1 { + handlerOutDir = genPaths[0] + } else if len(genPaths) > 0 { + return nil, fmt.Errorf("too many 'api.handler_path' for %s", m.Name) + } + + hmethod, path := httpAnnos[0].method, httpAnnos[0].path + if len(path) == 0 || path[0] == "" { + return nil, fmt.Errorf("invalid api.%s for %s.%s: %s", hmethod, s.Name, m.Name, path) + } + + var reqName, reqRawName, reqPackage string + if len(m.Arguments) >= 1 { + if len(m.Arguments) > 1 { + logs.Warnf("function '%s' has more than one argument, but only the first can be used in hertz now", m.GetName()) + } + var err error + reqName, err = resolver.ResolveTypeName(m.Arguments[0].GetType()) + if err != nil { + return nil, err + } + if strings.Contains(reqName, ".") && !m.Arguments[0].GetType().Category.IsContainerType() { + // If reqName contains "." , then it must be of the form "pkg.name". + // so reqRawName='name', reqPackage='pkg' + names := strings.Split(reqName, ".") + if len(names) != 2 { + return nil, fmt.Errorf("request name: %s is wrong", reqName) + } + reqRawName = names[1] + reqPackage = names[0] + } + } + var respName, respRawName, respPackage string + if !m.Oneway { + var err error + respName, err = resolver.ResolveTypeName(m.GetFunctionType()) + if err != nil { + return nil, err + } + if strings.Contains(respName, ".") && !m.GetFunctionType().Category.IsContainerType() { + names := strings.Split(respName, ".") + if len(names) != 2 { + return nil, fmt.Errorf("response name: %s is wrong", respName) + } + // If respName contains "." , then it must be of the form "pkg.name". + // so respRawName='name', respPackage='pkg' + respRawName = names[1] + respPackage = names[0] + } + } + + sr, _ := util.GetFirstKV(getAnnotations(m.Annotations, SerializerTags)) + method := &generator.HttpMethod{ + Name: util.CamelString(m.GetName()), + HTTPMethod: hmethod, + RequestTypeName: reqName, + RequestTypeRawName: reqRawName, + RequestTypePackage: reqPackage, + ReturnTypeName: respName, + ReturnTypeRawName: respRawName, + ReturnTypePackage: respPackage, + Path: path[0], + Serializer: sr, + OutputDir: handlerOutDir, + GenHandler: true, + // Annotations: m.Annotations, + } + refs := resolver.ExportReferred(false, true) + method.Models = make(map[string]*model.Model, len(refs)) + for _, ref := range refs { + if v, ok := method.Models[ref.Model.PackageName]; ok && (v.Package != ref.Model.Package) { + return nil, fmt.Errorf("Package name: %s redeclared in %s and %s ", ref.Model.PackageName, v.Package, ref.Model.Package) + } + method.Models[ref.Model.PackageName] = ref.Model + } + models.MergeMap(method.Models) + methods = append(methods, method) + for idx, anno := range httpAnnos { + for i := 0; i < len(anno.path); i++ { + if idx == 0 && i == 0 { // idx==0 && i==0 has been added above + continue + } + newMethod, err := newHTTPMethod(s, m, method, i, anno) + if err != nil { + return nil, err + } + methods = append(methods, newMethod) + } + } + if args.CmdType == meta.CmdClient { + clientMethod := &generator.ClientMethod{} + clientMethod.HttpMethod = method + rt, err := resolver.ResolveIdentifier(m.Arguments[0].GetType().GetName()) + if err != nil { + return nil, err + } + err = parseAnnotationToClient(clientMethod, m.Arguments[0].GetType(), rt) + if err != nil { + return nil, err + } + clientMethods = append(clientMethods, clientMethod) + } + } + + service.ClientMethods = clientMethods + service.Methods = methods + service.Models = models + out = append(out, service) + } + return out, nil +} + +func newHTTPMethod(s *parser.Service, m *parser.Function, method *generator.HttpMethod, i int, anno httpAnnotation) (*generator.HttpMethod, error) { + newMethod := *method + hmethod, path := anno.method, anno.path + if path[i] == "" { + return nil, fmt.Errorf("invalid api.%s for %s.%s: %s", hmethod, s.Name, m.Name, path[i]) + } + newMethod.HTTPMethod = hmethod + newMethod.Path = path[i] + newMethod.GenHandler = false + return &newMethod, nil +} + +func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Type, symbol ResolvedSymbol) error { + if p == nil { + return fmt.Errorf("get type failed for parse annotatoon to client") + } + typeName := p.GetName() + if strings.Contains(typeName, ".") { + ret := strings.Split(typeName, ".") + typeName = ret[len(ret)-1] + } + scope, err := golang.BuildScope(thriftgoUtil, symbol.Scope) + if err != nil { + return fmt.Errorf("can not build scope for %s", p.Name) + } + thriftgoUtil.SetRootScope(scope) + st := scope.StructLike(typeName) + if st == nil { + logs.Infof("the type '%s' for method '%s' is base type, so skip parse client info\n") + return nil + } + var ( + hasBodyAnnotation bool + hasFormAnnotation bool + ) + for _, field := range st.Fields() { + hasAnnotation := false + isStringFieldType := false + if field.GetType().String() == "string" { + isStringFieldType = true + } + if anno := getAnnotation(field.Annotations, AnnotationQuery); len(anno) > 0 { + hasAnnotation = true + query := checkSnakeName(anno[0]) + clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", query, field.GoName().String()) + } + + if anno := getAnnotation(field.Annotations, AnnotationPath); len(anno) > 0 { + hasAnnotation = true + path := checkSnakeName(anno[0]) + if isStringFieldType { + clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", path, field.GoName().String()) + } else { + clientMethod.PathParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", path, field.GoName().String()) + } + } + + if anno := getAnnotation(field.Annotations, AnnotationHeader); len(anno) > 0 { + hasAnnotation = true + header := checkSnakeName(anno[0]) + if isStringFieldType { + clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", header, field.GoName().String()) + } else { + clientMethod.HeaderParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", header, field.GoName().String()) + } + } + + if anno := getAnnotation(field.Annotations, AnnotationForm); len(anno) > 0 { + hasAnnotation = true + form := checkSnakeName(anno[0]) + hasFormAnnotation = true + if isStringFieldType { + clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", form, field.GoName().String()) + } else { + clientMethod.FormValueCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", form, field.GoName().String()) + } + } + + if anno := getAnnotation(field.Annotations, AnnotationBody); len(anno) > 0 { + hasAnnotation = true + hasBodyAnnotation = true + } + + if anno := getAnnotation(field.Annotations, AnnotationFileName); len(anno) > 0 { + hasAnnotation = true + fileName := checkSnakeName(anno[0]) + hasFormAnnotation = true + clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", fileName, field.GoName().String()) + } + if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") { + clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(field.GetName()), field.GoName().String()) + } + } + clientMethod.BodyParamsCode = meta.SetBodyParam + if hasBodyAnnotation && hasFormAnnotation { + clientMethod.FormValueCode = "" + clientMethod.FormFileCode = "" + } + if !hasBodyAnnotation && hasFormAnnotation { + clientMethod.BodyParamsCode = "" + } + + return nil +} + +type extendServiceList []string + +func (svr extendServiceList) exist(serviceName string) bool { + for _, s := range svr { + if s == serviceName { + return true + } + } + return false +} + +func getExtendServices(ast *parser.Thrift) (res extendServiceList) { + for a := range ast.DepthFirstSearch() { + for _, svc := range a.Services { + if len(svc.Extends) > 0 { + res = append(res, svc.Extends) + } + } + } + return +} + +func getAllExtendFunction(svc *parser.Service, ast *parser.Thrift, resolver *Resolver, args *config.Argument) (res []*parser.Function, err error) { + if len(svc.Extends) == 0 { + return + } + parts := semantic.SplitType(svc.Extends) + switch len(parts) { + case 1: + if resolver.mainPkg.Ast.Filename == ast.Filename { // extended current service for master IDL + extendSvc, found := ast.GetService(parts[0]) + if found { + funcs := extendSvc.GetFunctions() + // determine if it still has extends + extendFuncs, err := getAllExtendFunction(extendSvc, ast, resolver, args) + if err != nil { + return nil, err + } + res = append(res, append(funcs, extendFuncs...)...) + } + return res, nil + } else { // extended current service for other IDL + extendSvc, found := ast.GetService(parts[0]) + if found { + base, err := addResolverDependency(resolver, ast, args) + if err != nil { + return nil, err + } + funcs := extendSvc.GetFunctions() + for _, f := range funcs { + processExtendsType(f, base) + } + extendFuncs, err := getAllExtendFunction(extendSvc, ast, resolver, args) + if err != nil { + return nil, err + } + res = append(res, append(funcs, extendFuncs...)...) + } + return res, nil + } + case 2: + refAst, found := ast.GetReference(parts[0]) + base, err := addResolverDependency(resolver, refAst, args) + if err != nil { + return nil, err + } + // ff the service extends from other files, it has to resolve the dependencies of other files as well + for _, dep := range refAst.Includes { + _, err := addResolverDependency(resolver, dep.Reference, args) + if err != nil { + return nil, err + } + } + if found { + extendSvc, found := refAst.GetService(parts[1]) + if found { + funcs := extendSvc.GetFunctions() + for _, f := range funcs { + processExtendsType(f, base) + } + extendFuncs, err := getAllExtendFunction(extendSvc, refAst, resolver, args) + if err != nil { + return nil, err + } + res = append(res, append(funcs, extendFuncs...)...) + } + } + return res, nil + } + + return res, nil +} + +func processExtendsType(f *parser.Function, base string) { + // the method of other file is extended, and the package of req/resp needs to be changed + // ex. base.thrift -> Resp Method(Req){} + // base.Resp Method(base.Req){} + if len(f.Arguments) > 0 { + if f.Arguments[0].Type.Category.IsContainerType() { + switch f.Arguments[0].Type.Category { + case parser.Category_Set, parser.Category_List: + if !strings.Contains(f.Arguments[0].Type.ValueType.Name, ".") && f.Arguments[0].Type.ValueType.Category.IsStruct() { + f.Arguments[0].Type.ValueType.Name = base + "." + f.Arguments[0].Type.ValueType.Name + } + case parser.Category_Map: + if !strings.Contains(f.Arguments[0].Type.ValueType.Name, ".") && f.Arguments[0].Type.ValueType.Category.IsStruct() { + f.Arguments[0].Type.ValueType.Name = base + "." + f.Arguments[0].Type.ValueType.Name + } + if !strings.Contains(f.Arguments[0].Type.KeyType.Name, ".") && f.Arguments[0].Type.KeyType.Category.IsStruct() { + f.Arguments[0].Type.KeyType.Name = base + "." + f.Arguments[0].Type.KeyType.Name + } + } + } else { + if !strings.Contains(f.Arguments[0].Type.Name, ".") && f.Arguments[0].Type.Category.IsStruct() { + f.Arguments[0].Type.Name = base + "." + f.Arguments[0].Type.Name + } + } + } + + if f.FunctionType.Category.IsContainerType() { + switch f.FunctionType.Category { + case parser.Category_Set, parser.Category_List: + if !strings.Contains(f.FunctionType.ValueType.Name, ".") && f.FunctionType.ValueType.Category.IsStruct() { + f.FunctionType.ValueType.Name = base + "." + f.FunctionType.ValueType.Name + } + case parser.Category_Map: + if !strings.Contains(f.FunctionType.ValueType.Name, ".") && f.FunctionType.ValueType.Category.IsStruct() { + f.FunctionType.ValueType.Name = base + "." + f.FunctionType.ValueType.Name + } + if !strings.Contains(f.FunctionType.KeyType.Name, ".") && f.FunctionType.KeyType.Category.IsStruct() { + f.FunctionType.KeyType.Name = base + "." + f.FunctionType.KeyType.Name + } + } + } else { + if !strings.Contains(f.FunctionType.Name, ".") && f.FunctionType.Category.IsStruct() { + f.FunctionType.Name = base + "." + f.FunctionType.Name + } + } +} + +func getUniqueResolveDependentName(name string, resolver *Resolver) string { + rawName := name + for i := 0; i < 10000; i++ { + if _, exist := resolver.deps[name]; !exist { + return name + } + name = rawName + fmt.Sprint(i) + } + + return name +} + +func addResolverDependency(resolver *Resolver, ast *parser.Thrift, args *config.Argument) (string, error) { + namespace, err := resolver.LoadOne(ast) + if err != nil { + return "", err + } + baseName := util.BaseName(ast.Filename, ".thrift") + if refPkg, exist := resolver.refPkgs[baseName]; !exist { + resolver.deps[baseName] = namespace + } else { + if ast.Filename != refPkg.Ast.Filename { + baseName = getUniqueResolveDependentName(baseName, resolver) + resolver.deps[baseName] = namespace + } + } + pkg := getGoPackage(ast, args.OptPkgMap) + impt := ast.Filename + pkgName := util.SplitPackageName(pkg, "") + pkgName, err = util.GetPackageUniqueName(pkgName) + if err != nil { + return "", err + } + ref := &PackageReference{baseName, impt, &model.Model{ + FilePath: ast.Filename, + Package: pkg, + PackageName: pkgName, + }, ast, false} + if _, exist := resolver.refPkgs[baseName]; !exist { + resolver.refPkgs[baseName] = ref + } + + return baseName, nil +} + +/*---------------------------Model-----------------------------*/ + +var BaseThrift = parser.Thrift{} + +var baseTypes = map[string]string{ + "bool": "bool", + "byte": "int8", + "i8": "int8", + "i16": "int16", + "i32": "int32", + "i64": "int64", + "double": "float64", + "string": "string", + "binary": "[]byte", +} + +func switchBaseType(typ *parser.Type) *model.Type { + switch typ.Name { + case "bool": + return model.TypeBool + case "byte": + return model.TypeByte + case "i8": + return model.TypeInt8 + case "i16": + return model.TypeInt16 + case "i32": + return model.TypeInt32 + case "i64": + return model.TypeInt64 + case "int": + return model.TypeInt + case "double": + return model.TypeFloat64 + case "string": + return model.TypeString + case "binary": + return model.TypeBinary + } + return nil +} + +func newBaseType(typ *model.Type, cg model.Category) *model.Type { + cyp := *typ + cyp.Category = cg + return &cyp +} + +func newStructType(name string, cg model.Category) *model.Type { + return &model.Type{ + Name: name, + Scope: nil, + Kind: model.KindStruct, + Category: cg, + Indirect: false, + Extra: nil, + HasNew: true, + } +} + +func newEnumType(name string, cg model.Category) *model.Type { + return &model.Type{ + Name: name, + Scope: &model.BaseModel, + Kind: model.KindInt, + Category: cg, + } +} + +func newFuncType(name string, cg model.Category) *model.Type { + return &model.Type{ + Name: name, + Scope: nil, + Kind: model.KindFunc, + Category: cg, + Indirect: false, + Extra: nil, + HasNew: false, + } +} + +func (resolver *Resolver) getFieldType(typ *parser.Type) (*model.Type, error) { + if dt, _ := resolver.getBaseType(typ); dt != nil { + return dt, nil + } + sb := resolver.Get(typ.Name) + if sb != nil { + return sb.Type, nil + } + return nil, fmt.Errorf("unknown type: %s", typ.Name) +} + +type ResolvedSymbol struct { + Base string + Src string + *Symbol +} + +func (rs ResolvedSymbol) Expression() string { + base, err := NameStyle.Identify(rs.Base) + if err != nil { + logs.Warnf("%s naming style for %s failed, fall back to %s, please refer to the variable manually!", NameStyle.Name(), rs.Base, rs.Base) + base = rs.Base + } + // base type no need to do name style + if model.IsBaseType(rs.Type) { + // base type mapping + if val, exist := baseTypes[rs.Base]; exist { + base = val + } + } + if rs.Src != "" { + if !rs.IsValue && model.IsBaseType(rs.Type) { + return base + } + return fmt.Sprintf("%s.%s", rs.Src, base) + } + return base +} + +func astToModel(ast *parser.Thrift, rs *Resolver) (*model.Model, error) { + main := rs.mainPkg.Model + if main == nil { + main = new(model.Model) + } + + // typedefs + tds := ast.GetTypedefs() + typdefs := make([]model.TypeDef, 0, len(tds)) + for _, t := range tds { + td := model.TypeDef{ + Scope: main, + Alias: t.Alias, + } + if bt, err := rs.ResolveType(t.Type); bt == nil || err != nil { + return nil, fmt.Errorf("%s has no type definition, error: %s", t.String(), err) + } else { + td.Type = bt + } + typdefs = append(typdefs, td) + } + main.Typedefs = typdefs + + // constants + cts := ast.GetConstants() + constants := make([]model.Constant, 0, len(cts)) + variables := make([]model.Variable, 0, len(cts)) + for _, c := range cts { + ft, err := rs.ResolveType(c.Type) + if err != nil { + return nil, err + } + if ft.Name == model.TypeBaseList.Name || ft.Name == model.TypeBaseMap.Name || ft.Name == model.TypeBaseSet.Name { + resolveValue, err := rs.ResolveConstantValue(c.Value) + if err != nil { + return nil, err + } + vt := model.Variable{ + Scope: main, + Name: c.Name, + Type: ft, + Value: resolveValue, + } + variables = append(variables, vt) + } else { + resolveValue, err := rs.ResolveConstantValue(c.Value) + if err != nil { + return nil, err + } + ct := model.Constant{ + Scope: main, + Name: c.Name, + Type: ft, + Value: resolveValue, + } + constants = append(constants, ct) + } + } + main.Constants = constants + main.Variables = variables + + // Enums + ems := ast.GetEnums() + enums := make([]model.Enum, 0, len(ems)) + for _, e := range ems { + em := model.Enum{ + Scope: main, + Name: e.GetName(), + GoType: "int64", + } + vs := make([]model.Constant, 0, len(e.Values)) + for _, ee := range e.Values { + vs = append(vs, model.Constant{ + Scope: main, + Name: ee.Name, + Type: model.TypeInt64, + Value: model.IntExpression{Src: int(ee.Value)}, + }) + } + em.Values = vs + enums = append(enums, em) + } + main.Enums = enums + + // Structs + sts := make([]*parser.StructLike, 0, len(ast.Structs)) + sts = append(sts, ast.Structs...) + structs := make([]model.Struct, 0, len(ast.Structs)+len(ast.Unions)+len(ast.Exceptions)) + for _, st := range sts { + s := model.Struct{ + Scope: main, + Name: st.GetName(), + Category: model.CategoryStruct, + LeadingComments: removeCommentsSlash(st.GetReservedComments()), + } + + vs := make([]model.Field, 0, len(st.Fields)) + for _, f := range st.Fields { + fieldName, _ := (&styles.ThriftGo{}).Identify(f.Name) + isP, err := isPointer(f, rs) + if err != nil { + return nil, err + } + resolveType, err := rs.ResolveType(f.Type) + if err != nil { + return nil, err + } + field := model.Field{ + Scope: &s, + Name: fieldName, + Type: resolveType, + // IsSetDefault: f.IsSetDefault(), + LeadingComments: removeCommentsSlash(f.GetReservedComments()), + IsPointer: isP, + } + err = injectTags(f, &field, true, true) + if err != nil { + return nil, err + } + vs = append(vs, field) + } + checkDuplicatedFileName(vs) + s.Fields = vs + structs = append(structs, s) + } + + sts = make([]*parser.StructLike, 0, len(ast.Unions)) + sts = append(sts, ast.Unions...) + for _, st := range sts { + s := model.Struct{ + Scope: main, + Name: st.GetName(), + Category: model.CategoryUnion, + LeadingComments: removeCommentsSlash(st.GetReservedComments()), + } + vs := make([]model.Field, 0, len(st.Fields)) + for _, f := range st.Fields { + fieldName, _ := (&styles.ThriftGo{}).Identify(f.Name) + isP, err := isPointer(f, rs) + if err != nil { + return nil, err + } + resolveType, err := rs.ResolveType(f.Type) + if err != nil { + return nil, err + } + field := model.Field{ + Scope: &s, + Name: fieldName, + Type: resolveType, + LeadingComments: removeCommentsSlash(f.GetReservedComments()), + IsPointer: isP, + } + err = injectTags(f, &field, true, true) + if err != nil { + return nil, err + } + vs = append(vs, field) + } + checkDuplicatedFileName(vs) + s.Fields = vs + structs = append(structs, s) + } + + sts = make([]*parser.StructLike, 0, len(ast.Exceptions)) + sts = append(sts, ast.Exceptions...) + for _, st := range sts { + s := model.Struct{ + Scope: main, + Name: st.GetName(), + Category: model.CategoryException, + LeadingComments: removeCommentsSlash(st.GetReservedComments()), + } + vs := make([]model.Field, 0, len(st.Fields)) + for _, f := range st.Fields { + fieldName, _ := (&styles.ThriftGo{}).Identify(f.Name) + isP, err := isPointer(f, rs) + if err != nil { + return nil, err + } + resolveType, err := rs.ResolveType(f.Type) + if err != nil { + return nil, err + } + field := model.Field{ + Scope: &s, + Name: fieldName, + Type: resolveType, + LeadingComments: removeCommentsSlash(f.GetReservedComments()), + IsPointer: isP, + } + err = injectTags(f, &field, true, true) + if err != nil { + return nil, err + } + vs = append(vs, field) + } + checkDuplicatedFileName(vs) + s.Fields = vs + structs = append(structs, s) + } + main.Structs = structs + + // In case of only the service refers another model, therefore scanning service is necessary + ss := ast.GetServices() + var err error + for _, s := range ss { + for _, m := range s.GetFunctions() { + _, err = rs.ResolveType(m.GetFunctionType()) + if err != nil { + return nil, err + } + for _, a := range m.GetArguments() { + _, err = rs.ResolveType(a.GetType()) + if err != nil { + return nil, err + } + } + } + } + + return main, nil +} + +// removeCommentsSlash can remove double slash for comments with thrift +func removeCommentsSlash(comments string) string { + if comments == "" { + return "" + } + + return comments[2:] +} + +func isPointer(f *parser.Field, rs *Resolver) (bool, error) { + typ, err := rs.ResolveType(f.GetType()) + if err != nil { + return false, err + } + if typ == nil { + return false, fmt.Errorf("can not get type: %s for %s", f.GetType(), f.GetName()) + } + if typ.Kind == model.KindStruct || typ.Kind == model.KindMap || typ.Kind == model.KindSlice { + return false, nil + } + + if f.GetRequiredness().IsOptional() { + return true, nil + } else { + return false, nil + } +} + +func getNewFieldName(fieldName string, fieldNameSet map[string]bool) string { + if _, ex := fieldNameSet[fieldName]; ex { + fieldName = fieldName + "_" + return getNewFieldName(fieldName, fieldNameSet) + } + return fieldName +} + +func checkDuplicatedFileName(vs []model.Field) { + fieldNameSet := make(map[string]bool) + for i := 0; i < len(vs); i++ { + if _, ex := fieldNameSet[vs[i].Name]; ex { + newName := getNewFieldName(vs[i].Name, fieldNameSet) + fieldNameSet[newName] = true + vs[i].Name = newName + } else { + fieldNameSet[vs[i].Name] = true + } + } +} diff --git a/thrift/plugin.go b/thrift/plugin.go new file mode 100644 index 0000000..c315c9d --- /dev/null +++ b/thrift/plugin.go @@ -0,0 +1,445 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "errors" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/config" + "github.com/cloudwego/hertz/cmd/hz/generator" + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/hertz/cmd/hz/util/logs" + "github.com/cloudwego/thriftgo/generator/backend" + "github.com/cloudwego/thriftgo/generator/golang" + "github.com/cloudwego/thriftgo/generator/golang/styles" + "github.com/cloudwego/thriftgo/parser" + thriftgo_plugin "github.com/cloudwego/thriftgo/plugin" +) + +type Plugin struct { + req *thriftgo_plugin.Request + args *config.Argument + logger *logs.StdLogger + rmTags []string +} + +func (plugin *Plugin) Run() int { + plugin.setLogger() + args := &config.Argument{} + defer func() { + if args == nil { + return + } + if args.Verbose { + verboseLog := plugin.recvVerboseLogger() + if len(verboseLog) != 0 { + fmt.Fprintf(os.Stderr, verboseLog) + } + } else { + warning := plugin.recvWarningLogger() + if len(warning) != 0 { + fmt.Fprintf(os.Stderr, warning) + } + } + }() + + err := plugin.handleRequest() + if err != nil { + logs.Errorf("handle request failed: %s", err.Error()) + return meta.PluginError + } + + args, err = plugin.parseArgs() + if err != nil { + logs.Errorf("parse args failed: %s", err.Error()) + return meta.PluginError + } + plugin.rmTags = args.RmTags + if args.CmdType == meta.CmdModel { + // check tag options for model mode + CheckTagOption(plugin.args) + res, err := plugin.GetResponse(nil, args.OutDir) + if err != nil { + logs.Errorf("get response failed: %s", err.Error()) + return meta.PluginError + } + plugin.response(res) + if err != nil { + logs.Errorf("response failed: %s", err.Error()) + return meta.PluginError + } + return 0 + } + + err = plugin.initNameStyle() + if err != nil { + logs.Errorf("init naming style failed: %s", err.Error()) + return meta.PluginError + } + + options := CheckTagOption(plugin.args) + + pkgInfo, err := plugin.getPackageInfo() + if err != nil { + logs.Errorf("get http package info failed: %s", err.Error()) + return meta.PluginError + } + + customPackageTemplate := args.CustomizePackage + pkg, err := args.GetGoPackage() + if err != nil { + logs.Errorf("get go package failed: %s", err.Error()) + return meta.PluginError + } + handlerDir, err := args.GetHandlerDir() + if err != nil { + logs.Errorf("get handler dir failed: %s", err.Error()) + return meta.PluginError + } + routerDir, err := args.GetRouterDir() + if err != nil { + logs.Errorf("get router dir failed: %s", err.Error()) + return meta.PluginError + } + modelDir, err := args.GetModelDir() + if err != nil { + logs.Errorf("get model dir failed: %s", err.Error()) + return meta.PluginError + } + clientDir, err := args.GetClientDir() + if err != nil { + logs.Errorf("get client dir failed: %s", err.Error()) + return meta.PluginError + } + sg := generator.HttpPackageGenerator{ + ConfigPath: customPackageTemplate, + HandlerDir: handlerDir, + RouterDir: routerDir, + ModelDir: modelDir, + UseDir: args.Use, + ClientDir: clientDir, + TemplateGenerator: generator.TemplateGenerator{ + OutputDir: args.OutDir, + Excludes: args.Excludes, + }, + ProjPackage: pkg, + Options: options, + HandlerByMethod: args.HandlerByMethod, + CmdType: args.CmdType, + IdlClientDir: util.SubDir(modelDir, pkgInfo.Package), + ForceClientDir: args.ForceClientDir, + BaseDomain: args.BaseDomain, + SnakeStyleMiddleware: args.SnakeStyleMiddleware, + } + if args.ModelBackend != "" { + sg.Backend = meta.Backend(args.ModelBackend) + } + generator.SetDefaultTemplateConfig() + + err = sg.Generate(pkgInfo) + if err != nil { + logs.Errorf("generate package failed: %s", err.Error()) + return meta.PluginError + } + if len(args.Use) != 0 { + err = sg.Persist() + if err != nil { + logs.Errorf("persist file failed within '-use' option: %s", err.Error()) + return meta.PluginError + } + res := thriftgo_plugin.BuildErrorResponse(errors.New(meta.TheUseOptionMessage).Error()) + err = plugin.response(res) + if err != nil { + logs.Errorf("response failed: %s", err.Error()) + return meta.PluginError + } + return 0 + } + files, err := sg.GetFormatAndExcludedFiles() + if err != nil { + logs.Errorf("format file failed: %s", err.Error()) + return meta.PluginError + } + res, err := plugin.GetResponse(files, sg.OutputDir) + if err != nil { + logs.Errorf("get response failed: %s", err.Error()) + return meta.PluginError + } + err = plugin.response(res) + if err != nil { + logs.Errorf("response failed: %s", err.Error()) + return meta.PluginError + } + return 0 +} + +func (plugin *Plugin) setLogger() { + plugin.logger = logs.NewStdLogger(logs.LevelInfo) + plugin.logger.Defer = true + plugin.logger.ErrOnly = true + logs.SetLogger(plugin.logger) +} + +func (plugin *Plugin) recvWarningLogger() string { + warns := plugin.logger.Warn() + plugin.logger.Flush() + logs.SetLogger(logs.NewStdLogger(logs.LevelInfo)) + return warns +} + +func (plugin *Plugin) recvVerboseLogger() string { + info := plugin.logger.Out() + warns := plugin.logger.Warn() + verboseLog := string(info) + warns + plugin.logger.Flush() + logs.SetLogger(logs.NewStdLogger(logs.LevelInfo)) + return verboseLog +} + +func (plugin *Plugin) handleRequest() error { + data, err := ioutil.ReadAll(os.Stdin) + if err != nil { + return fmt.Errorf("read request failed: %s", err.Error()) + } + req, err := thriftgo_plugin.UnmarshalRequest(data) + if err != nil { + return fmt.Errorf("unmarshal request failed: %s", err.Error()) + } + plugin.req = req + // init thriftgo utils + thriftgoUtil = golang.NewCodeUtils(backend.DummyLogFunc()) + thriftgoUtil.HandleOptions(req.GeneratorParameters) + + return nil +} + +func (plugin *Plugin) parseArgs() (*config.Argument, error) { + if plugin.req == nil { + return nil, fmt.Errorf("request is nil") + } + args := new(config.Argument) + err := args.Unpack(plugin.req.PluginParameters) + if err != nil { + logs.Errorf("unpack args failed: %s", err.Error()) + } + plugin.args = args + return args, nil +} + +// initNameStyle initializes the naming style based on the "naming_style" option for thrift. +func (plugin *Plugin) initNameStyle() error { + if len(plugin.args.ThriftOptions) == 0 { + return nil + } + for _, opt := range plugin.args.ThriftOptions { + parts := strings.SplitN(opt, "=", 2) + if len(parts) == 2 && parts[0] == "naming_style" { + NameStyle = styles.NewNamingStyle(parts[1]) + if NameStyle == nil { + return fmt.Errorf(fmt.Sprintf("do not support \"%s\" naming style", parts[1])) + } + break + } + } + + return nil +} + +func (plugin *Plugin) getPackageInfo() (*generator.HttpPackage, error) { + req := plugin.req + args := plugin.args + + ast := req.GetAST() + if ast == nil { + return nil, fmt.Errorf("no ast") + } + logs.Infof("Processing %s", ast.GetFilename()) + + pkgMap := args.OptPkgMap + pkg := getGoPackage(ast, pkgMap) + main := &model.Model{ + FilePath: ast.Filename, + Package: pkg, + PackageName: util.SplitPackageName(pkg, ""), + } + rs, err := NewResolver(ast, main, pkgMap) + if err != nil { + return nil, fmt.Errorf("new thrift resolver failed, err:%v", err) + } + err = rs.LoadAll(ast) + if err != nil { + return nil, err + } + + idlPackage := getGoPackage(ast, pkgMap) + if idlPackage == "" { + return nil, fmt.Errorf("go package for '%s' is not defined", ast.GetFilename()) + } + + services, err := astToService(ast, rs, args) + if err != nil { + return nil, err + } + var models model.Models + for _, s := range services { + models.MergeArray(s.Models) + } + + return &generator.HttpPackage{ + Services: services, + IdlName: ast.GetFilename(), + Package: idlPackage, + Models: models, + }, nil +} + +func (plugin *Plugin) response(res *thriftgo_plugin.Response) error { + data, err := thriftgo_plugin.MarshalResponse(res) + if err != nil { + return fmt.Errorf("marshal response failed: %s", err.Error()) + } + _, err = os.Stdout.Write(data) + if err != nil { + return fmt.Errorf("write response failed: %s", err.Error()) + } + return nil +} + +func (plugin *Plugin) InsertTag() ([]*thriftgo_plugin.Generated, error) { + var res []*thriftgo_plugin.Generated + + if plugin.args.NoRecurse { + outPath := plugin.req.OutputPath + packageName := getGoPackage(plugin.req.AST, nil) + fileName := util.BaseNameAndTrim(plugin.req.AST.GetFilename()) + ".go" + outPath = filepath.Join(outPath, packageName, fileName) + for _, st := range plugin.req.AST.Structs { + stName := st.GetName() + for _, f := range st.Fields { + fieldName := f.GetName() + tagString, err := getTagString(f, plugin.rmTags) + if err != nil { + return nil, err + } + insertPointer := "struct." + stName + "." + fieldName + "." + "tag" + gen := &thriftgo_plugin.Generated{ + Content: tagString, + Name: &outPath, + InsertionPoint: &insertPointer, + } + res = append(res, gen) + } + } + return res, nil + } + + for ast := range plugin.req.AST.DepthFirstSearch() { + outPath := plugin.req.OutputPath + packageName := getGoPackage(ast, nil) + fileName := util.BaseNameAndTrim(ast.GetFilename()) + ".go" + outPath = filepath.Join(outPath, packageName, fileName) + + for _, st := range ast.Structs { + stName := st.GetName() + for _, f := range st.Fields { + fieldName := f.GetName() + tagString, err := getTagString(f, plugin.rmTags) + if err != nil { + return nil, err + } + insertPointer := "struct." + stName + "." + fieldName + "." + "tag" + gen := &thriftgo_plugin.Generated{ + Content: tagString, + Name: &outPath, + InsertionPoint: &insertPointer, + } + res = append(res, gen) + } + } + } + return res, nil +} + +func (plugin *Plugin) GetResponse(files []generator.File, outputDir string) (*thriftgo_plugin.Response, error) { + var contents []*thriftgo_plugin.Generated + for _, file := range files { + filePath := filepath.Join(outputDir, file.Path) + content := &thriftgo_plugin.Generated{ + Content: file.Content, + Name: &filePath, + } + contents = append(contents, content) + } + + insertTag, err := plugin.InsertTag() + if err != nil { + return nil, err + } + + contents = append(contents, insertTag...) + + return &thriftgo_plugin.Response{ + Contents: contents, + }, nil +} + +func getTagString(f *parser.Field, rmTags []string) (string, error) { + field := model.Field{} + err := injectTags(f, &field, true, false) + if err != nil { + return "", err + } + disableTag := false + if v := getAnnotation(f.Annotations, AnnotationNone); len(v) > 0 { + if strings.EqualFold(v[0], "true") { + disableTag = true + } + } + + for _, rmTag := range rmTags { + for _, t := range field.Tags { + if t.IsDefault && strings.EqualFold(t.Key, rmTag) { + field.Tags.Remove(t.Key) + } + } + } + + var tagString string + tags := field.Tags + for idx, tag := range tags { + value := tag.Value + if disableTag { + value = "-" + } + if idx == 0 { + tagString += " " + tag.Key + ":\"" + value + "\"" + " " + } else if idx == len(tags)-1 { + tagString += tag.Key + ":\"" + value + "\"" + } else { + tagString += tag.Key + ":\"" + value + "\"" + " " + } + } + + return tagString, nil +} diff --git a/thrift/plugin_test.go b/thrift/plugin_test.go new file mode 100644 index 0000000..5e471b6 --- /dev/null +++ b/thrift/plugin_test.go @@ -0,0 +1,110 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "io/ioutil" + "testing" + + "github.com/cloudwego/hertz/cmd/hz/generator" + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/thriftgo/plugin" +) + +func TestRun(t *testing.T) { + data, err := ioutil.ReadFile("../testdata/request_thrift.out") + if err != nil { + t.Fatal(err) + } + + req, err := plugin.UnmarshalRequest(data) + if err != nil { + t.Fatal(err) + } + + plu := new(Plugin) + plu.setLogger() + + plu.req = req + + _, err = plu.parseArgs() + if err != nil { + t.Fatal(err) + } + options := CheckTagOption(plu.args) + + pkgInfo, err := plu.getPackageInfo() + if err != nil { + t.Fatal(err) + } + + args := plu.args + customPackageTemplate := args.CustomizePackage + pkg, err := args.GetGoPackage() + if err != nil { + t.Fatal(err) + } + handlerDir, err := args.GetHandlerDir() + if err != nil { + t.Fatal(err) + } + routerDir, err := args.GetRouterDir() + if err != nil { + t.Fatal(err) + } + modelDir, err := args.GetModelDir() + if err != nil { + t.Fatal(err) + } + clientDir, err := args.GetClientDir() + if err != nil { + t.Fatal(err) + } + sg := generator.HttpPackageGenerator{ + ConfigPath: customPackageTemplate, + HandlerDir: handlerDir, + RouterDir: routerDir, + ModelDir: modelDir, + ClientDir: clientDir, + TemplateGenerator: generator.TemplateGenerator{ + OutputDir: args.OutDir, + }, + ProjPackage: pkg, + Options: options, + } + if args.ModelBackend != "" { + sg.Backend = meta.Backend(args.ModelBackend) + } + + err = sg.Generate(pkgInfo) + if err != nil { + t.Fatalf("generate package failed: %v", err) + } + files, err := sg.GetFormatAndExcludedFiles() + if err != nil { + return + } + + res, err := plu.GetResponse(files, sg.OutputDir) + if err != nil { + return + } + plu.response(res) + if err != nil { + return + } +} diff --git a/thrift/resolver.go b/thrift/resolver.go new file mode 100644 index 0000000..118d0fa --- /dev/null +++ b/thrift/resolver.go @@ -0,0 +1,592 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "fmt" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/thriftgo/parser" +) + +var ( + ConstTrue = Symbol{ + IsValue: true, + Type: model.TypeBool, + Value: true, + Scope: &BaseThrift, + } + ConstFalse = Symbol{ + IsValue: true, + Type: model.TypeBool, + Value: false, + Scope: &BaseThrift, + } + ConstEmptyString = Symbol{ + IsValue: true, + Type: model.TypeString, + Value: "", + Scope: &BaseThrift, + } +) + +type PackageReference struct { + IncludeBase string + IncludePath string + Model *model.Model + Ast *parser.Thrift + Referred bool +} + +func getReferPkgMap(pkgMap map[string]string, incs []*parser.Include, mainModel *model.Model) (map[string]*PackageReference, error) { + var err error + out := make(map[string]*PackageReference, len(pkgMap)) + pkgAliasMap := make(map[string]string, len(incs)) + // bugfix: add main package to avoid namespace conflict + mainPkg := mainModel.Package + mainPkgName := mainModel.PackageName + mainPkgName, err = util.GetPackageUniqueName(mainPkgName) + if err != nil { + return nil, err + } + pkgAliasMap[mainPkg] = mainPkgName + for _, inc := range incs { + pkg := getGoPackage(inc.Reference, pkgMap) + impt := inc.GetPath() + base := util.BaseNameAndTrim(impt) + pkgName := util.SplitPackageName(pkg, "") + if pn, exist := pkgAliasMap[pkg]; exist { + pkgName = pn + } else { + pkgName, err = util.GetPackageUniqueName(pkgName) + pkgAliasMap[pkg] = pkgName + if err != nil { + return nil, fmt.Errorf("get package unique name failed, err: %v", err) + } + } + out[base] = &PackageReference{base, impt, &model.Model{ + FilePath: inc.Path, + Package: pkg, + PackageName: pkgName, + }, inc.Reference, false} + } + + return out, nil +} + +type Symbol struct { + IsValue bool + Type *model.Type + Value interface{} + Scope *parser.Thrift +} + +type NameSpace map[string]*Symbol + +type Resolver struct { + // idl symbols + root NameSpace + deps map[string]NameSpace + + // exported models + mainPkg PackageReference + refPkgs map[string]*PackageReference +} + +func NewResolver(ast *parser.Thrift, model *model.Model, pkgMap map[string]string) (*Resolver, error) { + pm, err := getReferPkgMap(pkgMap, ast.GetIncludes(), model) + if err != nil { + return nil, fmt.Errorf("get package map failed, err: %v", err) + } + file := ast.GetFilename() + return &Resolver{ + root: make(NameSpace), + deps: make(map[string]NameSpace), + refPkgs: pm, + mainPkg: PackageReference{ + IncludeBase: util.BaseNameAndTrim(file), + IncludePath: ast.GetFilename(), + Model: model, + Ast: ast, + Referred: false, + }, + }, nil +} + +func (resolver *Resolver) GetRefModel(includeBase string) (*model.Model, error) { + if includeBase == "" { + return resolver.mainPkg.Model, nil + } + ref, ok := resolver.refPkgs[includeBase] + if !ok { + return nil, fmt.Errorf("not found include %s", includeBase) + } + return ref.Model, nil +} + +func (resolver *Resolver) getBaseType(typ *parser.Type) (*model.Type, bool) { + tt := switchBaseType(typ) + if tt != nil { + return tt, true + } + if typ.Name == "map" { + t := *model.TypeBaseMap + return &t, false + } + if typ.Name == "list" { + t := *model.TypeBaseList + return &t, false + } + if typ.Name == "set" { + t := *model.TypeBaseList + return &t, false + } + return nil, false +} + +func (resolver *Resolver) ResolveType(typ *parser.Type) (*model.Type, error) { + bt, base := resolver.getBaseType(typ) + if bt != nil { + if base { + return bt, nil + } else { + if typ.Name == model.TypeBaseMap.Name { + resolveKey, err := resolver.ResolveType(typ.KeyType) + if err != nil { + return nil, err + } + resolveValue, err := resolver.ResolveType(typ.ValueType) + if err != nil { + return nil, err + } + bt.Extra = append(bt.Extra, resolveKey, resolveValue) + } else if typ.Name == model.TypeBaseList.Name || typ.Name == model.TypeBaseSet.Name { + resolveValue, err := resolver.ResolveType(typ.ValueType) + if err != nil { + return nil, err + } + bt.Extra = append(bt.Extra, resolveValue) + } else { + return nil, fmt.Errorf("invalid DefinitionType(%+v)", bt) + } + return bt, nil + } + } + + id := typ.GetName() + rs, err := resolver.ResolveIdentifier(id) + if err != nil { + return nil, err + } + sb := rs.Symbol + if sb == nil { + return nil, fmt.Errorf("not found identifier %s", id) + } + return sb.Type, nil +} + +func (resolver *Resolver) ResolveConstantValue(constant *parser.ConstValue) (model.Literal, error) { + switch constant.Type { + case parser.ConstType_ConstInt: + return model.IntExpression{Src: int(constant.TypedValue.GetInt())}, nil + case parser.ConstType_ConstDouble: + return model.DoubleExpression{Src: constant.TypedValue.GetDouble()}, nil + case parser.ConstType_ConstLiteral: + return model.StringExpression{Src: constant.TypedValue.GetLiteral()}, nil + case parser.ConstType_ConstList: + eleType, err := switchConstantType(constant.Type) + if err != nil { + return nil, err + } + ret := model.ListExpression{ + ElementType: eleType, + } + for _, i := range constant.TypedValue.List { + elem, err := resolver.ResolveConstantValue(i) + if err != nil { + return nil, err + } + ret.Elements = append(ret.Elements, elem) + } + return ret, nil + case parser.ConstType_ConstMap: + keyType, err := switchConstantType(constant.TypedValue.Map[0].Key.Type) + if err != nil { + return nil, err + } + valueType, err := switchConstantType(constant.TypedValue.Map[0].Value.Type) + if err != nil { + return nil, err + } + ret := model.MapExpression{ + KeyType: keyType, + ValueType: valueType, + Elements: make(map[string]model.Literal, len(constant.TypedValue.Map)), + } + for _, v := range constant.TypedValue.Map { + value, err := resolver.ResolveConstantValue(v.Value) + if err != nil { + return nil, err + } + ret.Elements[v.Key.String()] = value + } + return ret, nil + case parser.ConstType_ConstIdentifier: + return resolver.ResolveIdentifier(*constant.TypedValue.Identifier) + } + return model.StringExpression{Src: constant.String()}, nil +} + +func (resolver *Resolver) ResolveIdentifier(id string) (ret ResolvedSymbol, err error) { + sb := resolver.Get(id) + if sb == nil { + return ResolvedSymbol{}, fmt.Errorf("identifier '%s' not found", id) + } + ret.Symbol = sb + ret.Base = id + if sb.Scope == &BaseThrift { + return + } + if sb.Scope == resolver.mainPkg.Ast { + resolver.mainPkg.Referred = true + ret.Src = resolver.mainPkg.Model.PackageName + return + } + + sp := strings.SplitN(id, ".", 2) + if ref, ok := resolver.refPkgs[sp[0]]; ok { + ref.Referred = true + ret.Base = sp[1] + ret.Src = ref.Model.PackageName + ret.Type.Scope = ref.Model + } else { + return ResolvedSymbol{}, fmt.Errorf("can't resolve identifier '%s'", id) + } + + return +} + +func (resolver *Resolver) ResolveTypeName(typ *parser.Type) (string, error) { + if typ.GetIsTypedef() { + rt, err := resolver.ResolveIdentifier(typ.GetName()) + if err != nil { + return "", err + } + + return rt.Expression(), nil + } + switch typ.GetCategory() { + case parser.Category_Map: + keyType, err := resolver.ResolveTypeName(typ.GetKeyType()) + if err != nil { + return "", err + } + if typ.GetKeyType().GetCategory().IsStruct() { + keyType = "*" + keyType + } + valueType, err := resolver.ResolveTypeName(typ.GetValueType()) + if err != nil { + return "", err + } + if typ.GetValueType().GetCategory().IsStruct() { + valueType = "*" + valueType + } + return fmt.Sprintf("map[%s]%s", keyType, valueType), nil + case parser.Category_List, parser.Category_Set: + // list/set -> []element for thriftgo + // valueType refers the element type for list/set + elemType, err := resolver.ResolveTypeName(typ.GetValueType()) + if err != nil { + return "", err + } + if typ.GetValueType().GetCategory().IsStruct() { + elemType = "*" + elemType + } + return fmt.Sprintf("[]%s", elemType), err + } + rt, err := resolver.ResolveIdentifier(typ.GetName()) + if err != nil { + return "", err + } + + return rt.Expression(), nil +} + +func (resolver *Resolver) Get(name string) *Symbol { + s, ok := resolver.root[name] + if ok { + return s + } + if strings.Contains(name, ".") { + sp := strings.SplitN(name, ".", 2) + if ref, ok := resolver.deps[sp[0]]; ok { + if ss, ok := ref[sp[1]]; ok { + return ss + } + } + } + return nil +} + +func (resolver *Resolver) ExportReferred(all, needMain bool) (ret []*PackageReference) { + for _, v := range resolver.refPkgs { + if all { + ret = append(ret, v) + v.Referred = false + } else if v.Referred { + ret = append(ret, v) + v.Referred = false + } + } + if needMain && (all || resolver.mainPkg.Referred) { + ret = append(ret, &resolver.mainPkg) + } + resolver.mainPkg.Referred = false + return +} + +func (resolver *Resolver) LoadAll(ast *parser.Thrift) error { + var err error + resolver.root, err = resolver.LoadOne(ast) + if err != nil { + return fmt.Errorf("load root package: %s", err) + } + + includes := ast.GetIncludes() + astMap := make(map[string]NameSpace, len(includes)) + for _, dep := range includes { + bName := util.BaseName(dep.Path, ".thrift") + astMap[bName], err = resolver.LoadOne(dep.Reference) + if err != nil { + return fmt.Errorf("load idl %s: %s", dep.Path, err) + } + } + resolver.deps = astMap + for _, td := range ast.Typedefs { + name := td.GetAlias() + if _, ex := resolver.root[name]; ex { + if resolver.root[name].Type != nil { + typ := newTypedefType(resolver.root[name].Type, name) + resolver.root[name].Type = &typ + continue + } + } + sym := resolver.Get(td.Type.GetName()) + typ := newTypedefType(sym.Type, name) + resolver.root[name].Type = &typ + } + return nil +} + +func LoadBaseIdentifier() NameSpace { + ret := make(NameSpace, 16) + + ret["true"] = &ConstTrue + ret["false"] = &ConstFalse + ret[`""`] = &ConstEmptyString + ret["bool"] = &Symbol{ + Type: model.TypeBool, + Scope: &BaseThrift, + } + ret["byte"] = &Symbol{ + Type: model.TypeByte, + Scope: &BaseThrift, + } + ret["i8"] = &Symbol{ + Type: model.TypeInt8, + Scope: &BaseThrift, + } + ret["i16"] = &Symbol{ + Type: model.TypeInt16, + Scope: &BaseThrift, + } + ret["i32"] = &Symbol{ + Type: model.TypeInt32, + Scope: &BaseThrift, + } + ret["i64"] = &Symbol{ + Type: model.TypeInt64, + Scope: &BaseThrift, + } + ret["int"] = &Symbol{ + Type: model.TypeInt, + Scope: &BaseThrift, + } + ret["double"] = &Symbol{ + Type: model.TypeFloat64, + Scope: &BaseThrift, + } + ret["string"] = &Symbol{ + Type: model.TypeString, + Scope: &BaseThrift, + } + ret["binary"] = &Symbol{ + Type: model.TypeBinary, + Scope: &BaseThrift, + } + ret["list"] = &Symbol{ + Type: model.TypeBaseList, + Scope: &BaseThrift, + } + ret["set"] = &Symbol{ + Type: model.TypeBaseSet, + Scope: &BaseThrift, + } + ret["map"] = &Symbol{ + Type: model.TypeBaseMap, + Scope: &BaseThrift, + } + return ret +} + +func (resolver *Resolver) LoadOne(ast *parser.Thrift) (NameSpace, error) { + ret := LoadBaseIdentifier() + + for _, e := range ast.Enums { + prefix := e.GetName() + ret[prefix] = &Symbol{ + IsValue: false, + Value: e, + Scope: ast, + Type: newEnumType(prefix, model.CategoryEnum), + } + for _, ee := range e.Values { + name := prefix + "." + ee.GetName() + if _, exist := ret[name]; exist { + return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) + } + + ret[name] = &Symbol{ + IsValue: true, + Value: ee, + Scope: ast, + Type: newBaseType(model.TypeInt, model.CategoryEnum), + } + } + } + + for _, e := range ast.Constants { + name := e.GetName() + if _, exist := ret[name]; exist { + return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) + } + gt, _ := resolver.getBaseType(e.Type) + ret[name] = &Symbol{ + IsValue: true, + Value: e, + Scope: ast, + Type: gt, + } + } + + for _, e := range ast.Structs { + name := e.GetName() + if _, exist := ret[name]; exist { + return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) + } + ret[name] = &Symbol{ + IsValue: false, + Value: e, + Scope: ast, + Type: newStructType(name, model.CategoryStruct), + } + } + + for _, e := range ast.Unions { + name := e.GetName() + if _, exist := ret[name]; exist { + return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) + } + ret[name] = &Symbol{ + IsValue: false, + Value: e, + Scope: ast, + Type: newStructType(name, model.CategoryStruct), + } + } + + for _, e := range ast.Exceptions { + name := e.GetName() + if _, exist := ret[name]; exist { + return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) + } + ret[name] = &Symbol{ + IsValue: false, + Value: e, + Scope: ast, + Type: newStructType(name, model.CategoryStruct), + } + } + + for _, e := range ast.Services { + name := e.GetName() + if _, exist := ret[name]; exist { + return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) + } + ret[name] = &Symbol{ + IsValue: false, + Value: e, + Scope: ast, + Type: newFuncType(name, model.CategoryService), + } + } + + for _, td := range ast.Typedefs { + name := td.GetAlias() + if _, exist := ret[name]; exist { + return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) + } + gt, _ := resolver.getBaseType(td.Type) + if gt == nil { + sym := ret[td.Type.Name] + if sym != nil { + gt = sym.Type + } + } + ret[name] = &Symbol{ + IsValue: false, + Value: td, + Scope: ast, + Type: gt, + } + } + + return ret, nil +} + +func switchConstantType(constant parser.ConstType) (*model.Type, error) { + switch constant { + case parser.ConstType_ConstInt: + return model.TypeInt, nil + case parser.ConstType_ConstDouble: + return model.TypeFloat64, nil + case parser.ConstType_ConstLiteral: + return model.TypeString, nil + default: + return nil, fmt.Errorf("unknown constant type %d", constant) + } +} + +func newTypedefType(t *model.Type, name string) model.Type { + tmp := t + typ := *tmp + typ.Name = name + typ.Category = model.CategoryTypedef + return typ +} diff --git a/thrift/tag_test.go b/thrift/tag_test.go new file mode 100644 index 0000000..108dea8 --- /dev/null +++ b/thrift/tag_test.go @@ -0,0 +1,129 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "io/ioutil" + "strings" + "testing" + + "github.com/cloudwego/hertz/cmd/hz/config" + "github.com/cloudwego/thriftgo/plugin" +) + +func TestInsertTag(t *testing.T) { + data, err := ioutil.ReadFile("./test_data/thrift_tag_test.out") + if err != nil { + t.Fatal(err) + } + req, err := plugin.UnmarshalRequest(data) + if err != nil { + t.Fatal(err) + } + + plu := new(Plugin) + plu.req = req + plu.args = new(config.Argument) + + type TagStruct struct { + Annotation string + GeneratedTag string + ActualTag string + } + + tagList := []TagStruct{ + { + Annotation: "query", + GeneratedTag: "json:\"DefaultQueryTag\" query:\"query\"", + }, + { + Annotation: "raw_body", + GeneratedTag: "json:\"RawBodyTag\" raw_body:\"raw_body\"", + }, + { + Annotation: "path", + GeneratedTag: "json:\"PathTag\" path:\"path\"", + }, + { + Annotation: "form", + GeneratedTag: "form:\"form\" json:\"FormTag\"", + }, + { + Annotation: "cookie", + GeneratedTag: "cookie:\"cookie\" json:\"CookieTag\"", + }, + { + Annotation: "header", + GeneratedTag: "header:\"header\" json:\"HeaderTag\"", + }, + { + Annotation: "body", + GeneratedTag: "form:\"body\" json:\"body\"", + }, + { + Annotation: "go.tag", + GeneratedTag: "", + }, + { + Annotation: "vd", + GeneratedTag: "form:\"VdTag\" json:\"VdTag\" query:\"VdTag\" vd:\"$!='?'\"", + }, + { + Annotation: "non", + GeneratedTag: "form:\"DefaultTag\" json:\"DefaultTag\" query:\"DefaultTag\"", + }, + { + Annotation: "query required", + GeneratedTag: "json:\"ReqQuery,required\" query:\"query,required\"", + }, + { + Annotation: "query optional", + GeneratedTag: "json:\"OptQuery,omitempty\" query:\"query\"", + }, + { + Annotation: "body required", + GeneratedTag: "form:\"body,required\" json:\"body,required\"", + }, + { + Annotation: "body optional", + GeneratedTag: "form:\"body\" json:\"body,omitempty\"", + }, + { + Annotation: "go.tag required", + GeneratedTag: "form:\"ReqGoTag,required\" query:\"ReqGoTag,required\"", + }, + { + Annotation: "go.tag optional", + GeneratedTag: "form:\"OptGoTag\" query:\"OptGoTag\"", + }, + { + Annotation: "go tag cover query", + GeneratedTag: "form:\"QueryGoTag,required\" json:\"QueryGoTag,required\"", + }, + } + + tags, err := plu.InsertTag() + if err != nil { + t.Fatal(err) + } + for i, tag := range tags { + tagList[i].ActualTag = tag.Content + if !strings.Contains(tagList[i].ActualTag, tagList[i].GeneratedTag) { + t.Fatalf("expected tag: '%s', but autual tag: '%s'", tagList[i].GeneratedTag, tagList[i].ActualTag) + } + } +} diff --git a/thrift/tags.go b/thrift/tags.go new file mode 100644 index 0000000..115d79b --- /dev/null +++ b/thrift/tags.go @@ -0,0 +1,370 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "fmt" + "sort" + "strconv" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/config" + "github.com/cloudwego/hertz/cmd/hz/generator" + "github.com/cloudwego/hertz/cmd/hz/generator/model" + "github.com/cloudwego/hertz/cmd/hz/util" + "github.com/cloudwego/thriftgo/parser" +) + +const ( + AnnotationQuery = "api.query" + AnnotationForm = "api.form" + AnnotationPath = "api.path" + AnnotationHeader = "api.header" + AnnotationCookie = "api.cookie" + AnnotationBody = "api.body" + AnnotationRawBody = "api.raw_body" + AnnotationJsConv = "api.js_conv" + AnnotationNone = "api.none" + AnnotationFileName = "api.file_name" + + AnnotationValidator = "api.vd" + + AnnotationGoTag = "go.tag" +) + +const ( + ApiGet = "api.get" + ApiPost = "api.post" + ApiPut = "api.put" + ApiPatch = "api.patch" + ApiDelete = "api.delete" + ApiOptions = "api.options" + ApiHEAD = "api.head" + ApiAny = "api.any" + ApiPath = "api.path" + ApiSerializer = "api.serializer" + ApiGenPath = "api.handler_path" +) + +const ( + ApiBaseDomain = "api.base_domain" + ApiServiceGroup = "api.service_group" + ApiServiceGenDir = "api.service_gen_dir" // handler_dir for handler_by_service + ApiServicePath = "api.service_path" // declare the path to the service's handler according to this annotation for handler_by_method +) + +var ( + HttpMethodAnnotations = map[string]string{ + ApiGet: "GET", + ApiPost: "POST", + ApiPut: "PUT", + ApiPatch: "PATCH", + ApiDelete: "DELETE", + ApiOptions: "OPTIONS", + ApiHEAD: "HEAD", + ApiAny: "ANY", + } + + HttpMethodOptionAnnotations = map[string]string{ + ApiGenPath: "handler_path", + } + + BindingTags = map[string]string{ + AnnotationPath: "path", + AnnotationQuery: "query", + AnnotationHeader: "header", + AnnotationCookie: "cookie", + AnnotationBody: "json", + AnnotationForm: "form", + AnnotationRawBody: "raw_body", + } + + SerializerTags = map[string]string{ + ApiSerializer: "serializer", + } + + ValidatorTags = map[string]string{AnnotationValidator: "vd"} +) + +var ( + jsonSnakeName = false + unsetOmitempty = false +) + +func CheckTagOption(args *config.Argument) []generator.Option { + var ret []generator.Option + if args == nil { + return ret + } + if args.SnakeName { + jsonSnakeName = true + } + if args.UnsetOmitempty { + unsetOmitempty = true + } + if args.JSONEnumStr { + ret = append(ret, generator.OptionMarshalEnumToText) + } + return ret +} + +func checkSnakeName(name string) string { + if jsonSnakeName { + name = util.ToSnakeCase(name) + } + return name +} + +func getAnnotation(input parser.Annotations, target string) []string { + if len(input) == 0 { + return nil + } + for _, anno := range input { + if strings.ToLower(anno.Key) == target { + return anno.Values + } + } + + return []string{} +} + +type httpAnnotation struct { + method string + path []string +} + +type httpAnnotations []httpAnnotation + +func (s httpAnnotations) Len() int { + return len(s) +} + +func (s httpAnnotations) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s httpAnnotations) Less(i, j int) bool { + return s[i].method < s[j].method +} + +func getAnnotations(input parser.Annotations, targets map[string]string) map[string][]string { + if len(input) == 0 || len(targets) == 0 { + return nil + } + out := map[string][]string{} + for k, t := range targets { + var ret *parser.Annotation + for _, anno := range input { + if strings.ToLower(anno.Key) == k { + ret = anno + break + } + } + if ret == nil { + continue + } + out[t] = ret.Values + } + return out +} + +func defaultBindingTags(f *parser.Field) []model.Tag { + out := make([]model.Tag, 3) + bindingTags := []string{ + AnnotationQuery, + AnnotationForm, + AnnotationPath, + AnnotationHeader, + AnnotationCookie, + AnnotationBody, + AnnotationRawBody, + } + + for _, tag := range bindingTags { + if v := getAnnotation(f.Annotations, tag); len(v) > 0 { + out[0] = jsonTag(f) + return out[:1] + } + } + + if v := getAnnotation(f.Annotations, AnnotationBody); len(v) > 0 { + val := getJsonValue(f, v[0]) + out[0] = tag("json", val) + } else { + t := jsonTag(f) + t.IsDefault = true + out[0] = t + } + if v := getAnnotation(f.Annotations, AnnotationQuery); len(v) > 0 { + val := checkRequire(f, v[0]) + out[1] = tag(BindingTags[AnnotationQuery], val) + } else { + val := checkRequire(f, checkSnakeName(f.Name)) + t := tag(BindingTags[AnnotationQuery], val) + t.IsDefault = true + out[1] = t + } + if v := getAnnotation(f.Annotations, AnnotationForm); len(v) > 0 { + val := checkRequire(f, v[0]) + out[2] = tag(BindingTags[AnnotationForm], val) + } else { + val := checkRequire(f, checkSnakeName(f.Name)) + t := tag(BindingTags[AnnotationForm], val) + t.IsDefault = true + out[2] = t + } + return out +} + +func jsonTag(f *parser.Field) (ret model.Tag) { + ret.Key = "json" + ret.Value = checkSnakeName(f.Name) + + if v := getAnnotation(f.Annotations, AnnotationJsConv); len(v) > 0 { + ret.Value += ",string" + } + if !unsetOmitempty && f.Requiredness == parser.FieldType_Optional { + ret.Value += ",omitempty" + } else if f.Requiredness == parser.FieldType_Required { + ret.Value += ",required" + } + return +} + +func tag(k, v string) model.Tag { + return model.Tag{ + Key: k, + Value: v, + } +} + +func annotationToTags(as parser.Annotations, targets map[string]string) (tags []model.Tag) { + rets := getAnnotations(as, targets) + for k, v := range rets { + for _, vv := range v { + tags = append(tags, model.Tag{ + Key: k, + Value: vv, + }) + } + } + return +} + +func injectTags(f *parser.Field, gf *model.Field, needDefault, needGoTag bool) error { + as := f.Annotations + if as == nil { + as = parser.Annotations{} + } + tags := gf.Tags + if tags == nil { + tags = make([]model.Tag, 0, len(as)) + } + + if needDefault { + tags = append(tags, defaultBindingTags(f)...) + } + + // binding tags + bts := annotationToTags(as, BindingTags) + for _, t := range bts { + key := t.Key + tags.Remove(key) + if key == "json" { + formVal := t.Value + t.Value = getJsonValue(f, t.Value) + formVal = checkRequire(f, formVal) + tags = append(tags, tag("form", formVal)) + } else { + t.Value = checkRequire(f, t.Value) + } + tags = append(tags, t) + } + + // validator tags + tags = append(tags, annotationToTags(as, ValidatorTags)...) + + // the tag defined by gotag with higher priority + checkGoTag(as, &tags) + + // go.tags for compiler mode + if needGoTag { + rets := getAnnotation(as, AnnotationGoTag) + for _, v := range rets { + gts := util.SplitGoTags(v) + for _, gt := range gts { + sp := strings.SplitN(gt, ":", 2) + if len(sp) != 2 { + return fmt.Errorf("invalid go tag: %s", v) + } + vv, err := strconv.Unquote(sp[1]) + if err != nil { + return fmt.Errorf("invalid go.tag value: %s, err: %v", sp[1], err.Error()) + } + key := sp[0] + tags.Remove(key) + tags = append(tags, model.Tag{ + Key: key, + Value: vv, + }) + } + } + } + + sort.Sort(tags) + gf.Tags = tags + return nil +} + +func getJsonValue(f *parser.Field, val string) string { + if v := getAnnotation(f.Annotations, AnnotationJsConv); len(v) > 0 { + val += ",string" + } + if !unsetOmitempty && f.Requiredness == parser.FieldType_Optional { + val += ",omitempty" + } else if f.Requiredness == parser.FieldType_Required { + val += ",required" + } + + return val +} + +func checkRequire(f *parser.Field, val string) string { + if f.Requiredness == parser.FieldType_Required { + val += ",required" + } + + return val +} + +// checkGoTag removes the tag defined in gotag +func checkGoTag(as parser.Annotations, tags *model.Tags) error { + rets := getAnnotation(as, AnnotationGoTag) + for _, v := range rets { + gts := util.SplitGoTags(v) + for _, gt := range gts { + sp := strings.SplitN(gt, ":", 2) + if len(sp) != 2 { + return fmt.Errorf("invalid go tag: %s", v) + } + key := sp[0] + tags.Remove(key) + } + } + + return nil +} diff --git a/thrift/test_data/test_tag.thrift b/thrift/test_data/test_tag.thrift new file mode 100644 index 0000000..ab3c56a --- /dev/null +++ b/thrift/test_data/test_tag.thrift @@ -0,0 +1,26 @@ +namespace go cloudwego.hertz.hz + +struct MultiTagReq { + // basic feature + 1: string DefaultQueryTag (api.query="query"); + 2: string RawBodyTag (api.raw_body="raw_body"); + 3: string PathTag (api.path="path"); + 4: string FormTag (api.form="form"); + 5: string CookieTag (api.cookie="cookie"); + 6: string HeaderTag (api.header="header"); + 7: string BodyTag (api.body="body"); + 8: string GoTag (go.tag="json:\"json\" query:\"query\" form:\"form\" header:\"header\" goTag:\"tag\""); + 9: string VdTag (api.vd="$!='?'"); + 10: string DefaultTag; + + // optional / required + 11: required string ReqQuery (api.query="query"); + 12: optional string OptQuery (api.query="query"); + 13: required string ReqBody (api.body="body"); + 14: optional string OptBody (api.body="body"); + 15: required string ReqGoTag (go.tag="json:\"json\""); + 16: optional string OptGoTag (go.tag="json:\"json\""); + + // gotag cover feature + 17: required string QueryGoTag (apt.query="query", go.tag="query:\"queryTag\"") +} \ No newline at end of file diff --git a/thrift/test_data/thrift_tag_test.out b/thrift/test_data/thrift_tag_test.out new file mode 100644 index 0000000..4e79681 Binary files /dev/null and b/thrift/test_data/thrift_tag_test.out differ diff --git a/thrift/thriftgo_util.go b/thrift/thriftgo_util.go new file mode 100644 index 0000000..0be9baa --- /dev/null +++ b/thrift/thriftgo_util.go @@ -0,0 +1,26 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "github.com/cloudwego/thriftgo/generator/golang" + "github.com/cloudwego/thriftgo/generator/golang/styles" +) + +var thriftgoUtil *golang.CodeUtils + +var NameStyle = styles.NewNamingStyle("thriftgo") diff --git a/util/ast.go b/util/ast.go new file mode 100644 index 0000000..b97d773 --- /dev/null +++ b/util/ast.go @@ -0,0 +1,65 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "path/filepath" + + "golang.org/x/tools/go/ast/astutil" +) + +func AddImport(file, alias, impt string) ([]byte, error) { + fset := token.NewFileSet() + path, _ := filepath.Abs(file) + f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("can not parse ast for file: %s, err: %v", path, err) + } + + return addImport(fset, f, alias, impt) +} + +func AddImportForContent(fileContent []byte, alias, impt string) ([]byte, error) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "", fileContent, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("can not parse ast for file: %s, err: %v", fileContent, err) + } + + return addImport(fset, f, alias, impt) +} + +func addImport(fset *token.FileSet, f *ast.File, alias, impt string) ([]byte, error) { + added := astutil.AddNamedImport(fset, f, alias, impt) + if !added { + return nil, fmt.Errorf("can not add import \"%s\" for file: %s", impt, f.Name.Name) + } + var output []byte + buffer := bytes.NewBuffer(output) + err := format.Node(buffer, fset, f) + if err != nil { + return nil, fmt.Errorf("can not add import for file: %s, err: %v", f.Name.Name, err) + } + + return buffer.Bytes(), nil +} diff --git a/util/ast_test.go b/util/ast_test.go new file mode 100644 index 0000000..8a7f54d --- /dev/null +++ b/util/ast_test.go @@ -0,0 +1,93 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import ( + "bytes" + "go/format" + "go/parser" + "go/token" + "testing" + + "golang.org/x/tools/go/ast/astutil" +) + +func TestAddImport(t *testing.T) { + inserts := [][]string{ + { + "ctx", + "context", + }, + { + "", + "context", + }, + } + files := [][]string{ + { + `package foo + +import ( + "fmt" + "time" +) +`, + `package foo + +import ( + ctx "context" + "fmt" + "time" +) +`, + }, + { + `package foo + +import ( + "fmt" + "time" +) +`, + `package foo + +import ( + "context" + "fmt" + "time" +) +`, + }, + } + for idx, file := range files { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "", file[0], parser.ImportsOnly) + if err != nil { + t.Fatalf("can not parse ast for file") + } + astutil.AddNamedImport(fset, f, inserts[idx][0], inserts[idx][1]) + var output []byte + buffer := bytes.NewBuffer(output) + err = format.Node(buffer, fset, f) + if err != nil { + t.Fatalf("can add import for file") + } + if buffer.String() != file[1] { + t.Fatalf("insert import fialed") + } + } +} diff --git a/util/data.go b/util/data.go new file mode 100644 index 0000000..52c3674 --- /dev/null +++ b/util/data.go @@ -0,0 +1,436 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import ( + "errors" + "fmt" + "net/url" + "path/filepath" + "reflect" + "regexp" + "strconv" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/util/logs" +) + +func CopyStringSlice(from, to *[]string) { + n := len(*from) + m := len(*to) + if n > m { + n = m + } + for i := 0; i < n; i++ { + (*to)[i] = (*from)[i] + } + *to = (*to)[:n] +} + +func CopyString2StringMap(from, to map[string]string) { + for k := range to { + delete(to, k) + } + for k, v := range from { + to[k] = v + } +} + +func PackArgs(c interface{}) (res []string, err error) { + t := reflect.TypeOf(c) + v := reflect.ValueOf(c) + if reflect.TypeOf(c).Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + if t.Kind() != reflect.Struct { + return nil, errors.New("passed c must be struct or pointer of struct") + } + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + x := v.Field(i) + n := f.Name + + if x.IsZero() { + continue + } + + switch x.Kind() { + case reflect.Bool: + if x.Bool() == false { + continue + } + res = append(res, n+"="+fmt.Sprint(x.Bool())) + case reflect.String: + if x.String() == "" { + continue + } + res = append(res, n+"="+x.String()) + case reflect.Slice: + if x.Len() == 0 { + continue + } + ft := f.Type.Elem() + if ft.Kind() != reflect.String { + return nil, fmt.Errorf("slice field %v must be '[]string', err: %v", f.Name, err.Error()) + } + var ss []string + for i := 0; i < x.Len(); i++ { + ss = append(ss, x.Index(i).String()) + } + res = append(res, n+"="+strings.Join(ss, ";")) + case reflect.Map: + if x.Len() == 0 { + continue + } + fk := f.Type.Key() + if fk.Kind() != reflect.String { + return nil, fmt.Errorf("map field %v must be 'map[string]string', err: %v", f.Name, err.Error()) + } + fv := f.Type.Elem() + if fv.Kind() != reflect.String { + return nil, fmt.Errorf("map field %v must be 'map[string]string', err: %v", f.Name, err.Error()) + } + var sk []string + it := x.MapRange() + for it.Next() { + sk = append(sk, it.Key().String()+"="+it.Value().String()) + } + res = append(res, n+"="+strings.Join(sk, ";")) + default: + return nil, fmt.Errorf("unsupported field type: %+v, err: %v", f, err.Error()) + } + } + return res, nil +} + +func UnpackArgs(args []string, c interface{}) error { + m, err := MapForm(args) + if err != nil { + return fmt.Errorf("unmarshal args failed, err: %v", err.Error()) + } + + t := reflect.TypeOf(c).Elem() + v := reflect.ValueOf(c).Elem() + if t.Kind() != reflect.Struct { + return errors.New("passed c must be struct or pointer of struct") + } + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + x := v.Field(i) + n := f.Name + values, ok := m[n] + if !ok || len(values) == 0 || values[0] == "" { + continue + } + switch x.Kind() { + case reflect.Bool: + if len(values) != 1 { + return fmt.Errorf("field %s can't be assigned multi values: %v", n, values) + } + x.SetBool(values[0] == "true") + case reflect.String: + if len(values) != 1 { + return fmt.Errorf("field %s can't be assigned multi values: %v", n, values) + } + x.SetString(values[0]) + case reflect.Slice: + if len(values) != 1 { + return fmt.Errorf("field %s can't be assigned multi values: %v", n, values) + } + ss := strings.Split(values[0], ";") + if x.Type().Elem().Kind() == reflect.Int { + n := reflect.MakeSlice(x.Type(), len(ss), len(ss)) + for i, s := range ss { + val, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + n.Index(i).SetInt(val) + } + x.Set(n) + } else { + for _, s := range ss { + val := reflect.Append(x, reflect.ValueOf(s)) + x.Set(val) + } + } + case reflect.Map: + if len(values) != 1 { + return fmt.Errorf("field %s can't be assigned multi values: %v", n, values) + } + ss := strings.Split(values[0], ";") + out := make(map[string]string, len(ss)) + for _, s := range ss { + sk := strings.SplitN(s, "=", 2) + if len(sk) != 2 { + return fmt.Errorf("map filed %v invalid key-value pair '%v'", n, s) + } + out[sk[0]] = sk[1] + } + x.Set(reflect.ValueOf(out)) + default: + return fmt.Errorf("field %s has unsupported type %+v", n, f.Type) + } + } + return nil +} + +func MapForm(input []string) (map[string][]string, error) { + out := make(map[string][]string, len(input)) + + for _, str := range input { + parts := strings.SplitN(str, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid argument: '%s'", str) + } + key, val := parts[0], parts[1] + out[key] = append(out[key], val) + } + + return out, nil +} + +func GetFirstKV(m map[string][]string) (string, []string) { + for k, v := range m { + return k, v + } + return "", nil +} + +func ToCamelCase(name string) string { + return CamelString(name) +} + +func ToSnakeCase(name string) string { + return SnakeString(name) +} + +// unifyPath will convert "\" to "/" in path if the os is windows +func unifyPath(path string) string { + if IsWindows() { + path = strings.ReplaceAll(path, "\\", "/") + } + return path +} + +// BaseName get base name for path. ex: "github.com/p.s.m" => "p.s.m" +func BaseName(include, subFixToTrim string) string { + include = unifyPath(include) + subFixToTrim = unifyPath(subFixToTrim) + last := include + if id := strings.LastIndex(last, "/"); id >= 0 && id < len(last)-1 { + last = last[id+1:] + } + if !strings.HasSuffix(last, subFixToTrim) { + return last + } + return last[:len(last)-len(subFixToTrim)] +} + +func BaseNameAndTrim(include string) string { + include = unifyPath(include) + last := include + if id := strings.LastIndex(last, "/"); id >= 0 && id < len(last)-1 { + last = last[id+1:] + } + + if id := strings.LastIndex(last, "."); id != -1 { + last = last[:id] + } + return last +} + +func SplitPackageName(pkg, subFixToTrim string) string { + pkg = unifyPath(pkg) + subFixToTrim = unifyPath(subFixToTrim) + last := SplitPackage(pkg, subFixToTrim) + if id := strings.LastIndex(last, "/"); id >= 0 && id < len(last)-1 { + last = last[id+1:] + } + return last +} + +func SplitPackage(pkg, subFixToTrim string) string { + pkg = unifyPath(pkg) + subFixToTrim = unifyPath(subFixToTrim) + last := strings.TrimSuffix(pkg, subFixToTrim) + if id := strings.LastIndex(last, "/"); id >= 0 && id < len(last)-1 { + last = last[id+1:] + } + return strings.ReplaceAll(last, ".", "/") +} + +func PathToImport(path, subFix string) string { + path = strings.TrimSuffix(path, subFix) + // path = RelativePath(path) + return strings.ReplaceAll(path, string(filepath.Separator), "/") +} + +func ImportToPath(path, subFix string) string { + // path = RelativePath(path) + return strings.ReplaceAll(path, "/", string(filepath.Separator)) + subFix +} + +func ImportToPathAndConcat(path, subFix string) string { + path = strings.TrimSuffix(path, subFix) + path = strings.ReplaceAll(path, "/", string(filepath.Separator)) + if i := strings.LastIndex(path, string(filepath.Separator)); i >= 0 && i < len(path)-1 && strings.Contains(path[i+1:], ".") { + base := strings.ReplaceAll(path[i+1:], ".", "_") + dir := path[:i] + return dir + string(filepath.Separator) + base + } + return path +} + +func ToVarName(paths []string) string { + ps := strings.Join(paths, "__") + input := []byte(url.PathEscape(ps)) + out := make([]byte, 0, len(input)) + for i := 0; i < len(input); i++ { + c := input[i] + if c == ':' || c == '*' { + continue + } + if (c >= '0' && c <= '9' && i != 0) || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c == '_') { + out = append(out, c) + } else { + out = append(out, '_') + } + } + + return string(out) +} + +func SplitGoTags(input string) []string { + out := make([]string, 0, 4) + ns := len(input) + + flag := false + prev := 0 + i := 0 + for i = 0; i < ns; i++ { + c := input[i] + if c == '"' { + flag = !flag + } + if !flag && c == ' ' { + if prev < i { + out = append(out, input[prev:i]) + } + prev = i + 1 + } + } + if i != 0 && prev < i { + out = append(out, input[prev:i]) + } + + return out +} + +func SubPackage(mod, dir string) string { + if dir == "" { + return mod + } + return mod + "/" + PathToImport(dir, "") +} + +func SubDir(root, subPkg string) string { + if root == "" { + return ImportToPath(subPkg, "") + } + return filepath.Join(root, ImportToPath(subPkg, "")) +} + +var ( + uniquePackageName = map[string]bool{} + uniqueMiddlewareName = map[string]bool{} + uniqueHandlerPackageName = map[string]bool{} +) + +// GetPackageUniqueName can get a non-repeating variable name for package alias +func GetPackageUniqueName(name string) (string, error) { + name, err := getUniqueName(name, uniquePackageName) + if err != nil { + return "", fmt.Errorf("can not generate unique name for package '%s', err: %v", name, err) + } + + return name, nil +} + +// GetMiddlewareUniqueName can get a non-repeating variable name for middleware name +func GetMiddlewareUniqueName(name string) (string, error) { + name, err := getUniqueName(name, uniqueMiddlewareName) + if err != nil { + return "", fmt.Errorf("can not generate routing group for path '%s', err: %v", name, err) + } + + return name, nil +} + +func GetHandlerPackageUniqueName(name string) (string, error) { + name, err := getUniqueName(name, uniqueHandlerPackageName) + if err != nil { + return "", fmt.Errorf("can not generate unique handler package name: '%s', err: %v", name, err) + } + + return name, nil +} + +// getUniqueName can get a non-repeating variable name +func getUniqueName(name string, uniqueNameSet map[string]bool) (string, error) { + uniqueName := name + if _, exist := uniqueNameSet[uniqueName]; exist { + for i := 0; i < 10000; i++ { + uniqueName = uniqueName + fmt.Sprintf("%d", i) + if _, exist := uniqueNameSet[uniqueName]; !exist { + logs.Infof("There is a package name with the same name, change %s to %s", name, uniqueName) + break + } + uniqueName = name + if i == 9999 { + return "", fmt.Errorf("there is too many same package for %s", name) + } + } + } + uniqueNameSet[uniqueName] = true + + return uniqueName, nil +} + +func SubPackageDir(path string) string { + index := strings.LastIndex(path, "/") + if index == -1 { + return "" + } + return path[:index] +} + +var validFuncReg = regexp.MustCompile("[_0-9a-zA-Z]") + +// ToGoFuncName converts a string to a function naming style for go +func ToGoFuncName(s string) string { + ss := []byte(s) + for i := range ss { + if !validFuncReg.Match([]byte{s[i]}) { + ss[i] = '_' + } + } + return string(ss) +} diff --git a/util/data_test.go b/util/data_test.go new file mode 100644 index 0000000..22ec51f --- /dev/null +++ b/util/data_test.go @@ -0,0 +1,78 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import "testing" + +func TestUniqueName(t *testing.T) { + type UniqueName struct { + Name string + ExpectedName string + ActualName string + } + + nameList := []UniqueName{ + { + Name: "aaa", + ExpectedName: "aaa", + }, + { + Name: "aaa", + ExpectedName: "aaa0", + }, + { + Name: "aaa0", + ExpectedName: "aaa00", + }, + { + Name: "aaa0", + ExpectedName: "aaa01", + }, + { + Name: "aaa00", + ExpectedName: "aaa000", + }, + { + Name: "aaa", + ExpectedName: "aaa1", + }, + { + Name: "aaa", + ExpectedName: "aaa2", + }, + { + Name: "aaa", + ExpectedName: "aaa3", + }, + { + Name: "aaa", + ExpectedName: "aaa4", + }, + } + for _, name := range nameList { + name.ActualName, _ = getUniqueName(name.Name, uniquePackageName) + if name.ActualName != name.ExpectedName { + t.Errorf("%s name expected unique name '%s', actually get '%s'", name.Name, name.ExpectedName, name.ActualName) + } + } + for _, name := range nameList { + name.ActualName, _ = getUniqueName(name.Name, uniqueMiddlewareName) + if name.ActualName != name.ExpectedName { + t.Errorf("%s name expected unique name '%s', actually get '%s'", name.Name, name.ExpectedName, name.ActualName) + } + } +} diff --git a/util/env.go b/util/env.go new file mode 100644 index 0000000..65289a6 --- /dev/null +++ b/util/env.go @@ -0,0 +1,138 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import ( + "bytes" + "fmt" + "go/build" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" + + "github.com/cloudwego/hertz/cmd/hz/meta" +) + +func GetGOPATH() (gopath string, err error) { + ps := filepath.SplitList(os.Getenv("GOPATH")) + if len(ps) > 0 { + gopath = ps[0] + } + if gopath == "" { + cmd := exec.Command("go", "env", "GOPATH") + var out bytes.Buffer + cmd.Stderr = &out + cmd.Stdout = &out + if err := cmd.Run(); err == nil { + gopath = strings.Trim(out.String(), " \t\n\r") + } + } + if gopath == "" { + ps := GetBuildGoPaths() + if len(ps) > 0 { + gopath = ps[0] + } + } + isExist, err := PathExist(gopath) + if !isExist { + return "", err + } + return strings.Replace(gopath, "/", string(os.PathSeparator), -1), nil +} + +// GetBuildGoPaths returns the list of Go path directories. +func GetBuildGoPaths() []string { + var all []string + for _, p := range filepath.SplitList(build.Default.GOPATH) { + if p == "" || p == build.Default.GOROOT { + continue + } + if strings.HasPrefix(p, "~") { + continue + } + all = append(all, p) + } + for k, v := range all { + if strings.HasSuffix(v, "/") || strings.HasSuffix(v, string(os.PathSeparator)) { + v = v[:len(v)-1] + } + all[k] = v + } + return all +} + +var goModReg = regexp.MustCompile(`^\s*module\s+(\S+)\s*`) + +// SearchGoMod searches go.mod from the given directory (which must be an absolute path) to +// the root directory. When the go.mod is found, its module name and path will be returned. +func SearchGoMod(cwd string, recurse bool) (moduleName, path string, found bool) { + for { + path = filepath.Join(cwd, "go.mod") + data, err := ioutil.ReadFile(path) + if err == nil { + for _, line := range strings.Split(string(data), "\n") { + m := goModReg.FindStringSubmatch(line) + if m != nil { + return m[1], cwd, true + } + } + return fmt.Sprintf("", path), path, true + } + + if !os.IsNotExist(err) { + return + } + if !recurse { + break + } + cwd = filepath.Dir(cwd) + // the root directory will return itself by using "filepath.Dir()"; to prevent dead loops, so jump out + if cwd == filepath.Dir(cwd) { + break + } + } + return +} + +func InitGoMod(module string) error { + isExist, err := PathExist("go.mod") + if err != nil { + return err + } + if isExist { + return nil + } + gg, err := exec.LookPath("go") + if err != nil { + return err + } + cmd := &exec.Cmd{ + Path: gg, + Args: []string{"go", "mod", "init", module}, + Stdin: os.Stdin, + Stdout: os.Stdout, + Stderr: os.Stderr, + } + return cmd.Run() +} + +func IsWindows() bool { + return meta.SysType == meta.WindowsOS +} diff --git a/util/fs.go b/util/fs.go new file mode 100644 index 0000000..e52aea2 --- /dev/null +++ b/util/fs.go @@ -0,0 +1,47 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import ( + "os" + "path/filepath" +) + +func PathExist(path string) (bool, error) { + abPath, err := filepath.Abs(path) + if err != nil { + return false, err + } + _, err = os.Stat(abPath) + if err != nil { + return os.IsExist(err), nil + } + return true, nil +} + +func RelativePath(path string) (string, error) { + path, err := filepath.Abs(path) + if err != nil { + return "", err + } + cwd, err := os.Getwd() + if err != nil { + return "", err + } + ret, _ := filepath.Rel(cwd, path) + return ret, nil +} diff --git a/util/logs/api.go b/util/logs/api.go new file mode 100644 index 0000000..09f86e5 --- /dev/null +++ b/util/logs/api.go @@ -0,0 +1,84 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logs + +func init() { + defaultLogger = NewStdLogger(LevelInfo) +} + +func SetLogger(logger Logger) { + defaultLogger = logger +} + +const ( + LevelDebug = 1 + iota + LevelInfo + LevelWarn + LevelError +) + +// TODO: merge with hertz logger package +type Logger interface { + Debugf(format string, v ...interface{}) + Infof(format string, v ...interface{}) + Warnf(format string, v ...interface{}) + Errorf(format string, v ...interface{}) + Flush() + SetLevel(level int) error +} + +var defaultLogger Logger + +func Errorf(format string, v ...interface{}) { + defaultLogger.Errorf(format, v...) +} + +func Warnf(format string, v ...interface{}) { + defaultLogger.Warnf(format, v...) +} + +func Infof(format string, v ...interface{}) { + defaultLogger.Infof(format, v...) +} + +func Debugf(format string, v ...interface{}) { + defaultLogger.Debugf(format, v...) +} + +func Error(format string, v ...interface{}) { + defaultLogger.Errorf(format, v...) +} + +func Warn(format string, v ...interface{}) { + defaultLogger.Warnf(format, v...) +} + +func Info(format string, v ...interface{}) { + defaultLogger.Infof(format, v...) +} + +func Debug(format string, v ...interface{}) { + defaultLogger.Debugf(format, v...) +} + +func Flush() { + defaultLogger.Flush() +} + +func SetLevel(level int) { + defaultLogger.SetLevel(level) +} diff --git a/util/logs/std.go b/util/logs/std.go new file mode 100644 index 0000000..3054aa1 --- /dev/null +++ b/util/logs/std.go @@ -0,0 +1,141 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logs + +import ( + "bytes" + "errors" + "fmt" + "log" + "os" +) + +type StdLogger struct { + level int + outLogger *log.Logger + warnLogger *log.Logger + errLogger *log.Logger + out *bytes.Buffer + warn *bytes.Buffer + err *bytes.Buffer + Defer bool + ErrOnly bool +} + +func NewStdLogger(level int) *StdLogger { + out := bytes.NewBuffer(nil) + warn := bytes.NewBuffer(nil) + err := bytes.NewBuffer(nil) + return &StdLogger{ + level: level, + outLogger: log.New(out, "[INFO]", log.Llongfile), + warnLogger: log.New(warn, "[WARN]", log.Llongfile), + errLogger: log.New(err, "[ERROR]", log.Llongfile), + out: out, + warn: warn, + err: err, + } +} + +func (stdLogger *StdLogger) Debugf(format string, v ...interface{}) { + if stdLogger.level > LevelDebug { + return + } + stdLogger.outLogger.Output(3, fmt.Sprintf(format, v...)) + if !stdLogger.Defer { + stdLogger.FlushOut() + } +} + +func (stdLogger *StdLogger) Infof(format string, v ...interface{}) { + if stdLogger.level > LevelInfo { + return + } + stdLogger.outLogger.Output(3, fmt.Sprintf(format, v...)) + if !stdLogger.Defer { + stdLogger.FlushOut() + } +} + +func (stdLogger *StdLogger) Warnf(format string, v ...interface{}) { + if stdLogger.level > LevelWarn { + return + } + stdLogger.warnLogger.Output(3, fmt.Sprintf(format, v...)) + if !stdLogger.Defer { + stdLogger.FlushErr() + } +} + +func (stdLogger *StdLogger) Errorf(format string, v ...interface{}) { + if stdLogger.level > LevelError { + return + } + stdLogger.errLogger.Output(3, fmt.Sprintf(format, v...)) + if !stdLogger.Defer { + stdLogger.FlushErr() + } +} + +func (stdLogger *StdLogger) Flush() { + stdLogger.FlushErr() + if !stdLogger.ErrOnly { + stdLogger.FlushOut() + } +} + +func (stdLogger *StdLogger) FlushOut() { + os.Stderr.Write(stdLogger.out.Bytes()) + stdLogger.out.Reset() +} + +func (stdLogger *StdLogger) Err() string { + return string(stdLogger.err.Bytes()) +} + +func (stdLogger *StdLogger) Warn() string { + return string(stdLogger.warn.Bytes()) +} + +func (stdLogger *StdLogger) FlushErr() { + os.Stderr.Write(stdLogger.err.Bytes()) + stdLogger.err.Reset() +} + +func (stdLogger *StdLogger) OutLines() []string { + lines := bytes.Split(stdLogger.out.Bytes(), []byte("[INFO]")) + var rets []string + for _, line := range lines { + rets = append(rets, string(line)) + } + return rets +} + +func (stdLogger *StdLogger) Out() []byte { + return stdLogger.out.Bytes() +} + +func (stdLogger *StdLogger) SetLevel(level int) error { + switch level { + case LevelDebug, LevelInfo, LevelWarn, LevelError: + break + default: + return errors.New("invalid log level") + } + stdLogger.level = level + return nil +} diff --git a/util/string.go b/util/string.go new file mode 100644 index 0000000..c317753 --- /dev/null +++ b/util/string.go @@ -0,0 +1,99 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import ( + "reflect" + "strings" + "unicode/utf8" + "unsafe" +) + +func Str2Bytes(in string) (out []byte) { + op := (*reflect.SliceHeader)(unsafe.Pointer(&out)) + ip := (*reflect.StringHeader)(unsafe.Pointer(&in)) + op.Data = ip.Data + op.Cap = ip.Len + op.Len = ip.Len + return +} + +func Bytes2Str(in []byte) (out string) { + op := (*reflect.StringHeader)(unsafe.Pointer(&out)) + ip := (*reflect.SliceHeader)(unsafe.Pointer(&in)) + op.Data = ip.Data + op.Len = ip.Len + return +} + +// TrimLastChar can remove the last char for s +func TrimLastChar(s string) string { + r, size := utf8.DecodeLastRuneInString(s) + if r == utf8.RuneError && (size == 0 || size == 1) { + size = 0 + } + return s[:len(s)-size] +} + +// AddSlashForComments can adjust the format of multi-line comments +func AddSlashForComments(s string) string { + s = strings.Replace(s, "\n", "\n//", -1) + return s +} + +// CamelString converts the string 's' to a camel string +func CamelString(s string) string { + data := make([]byte, 0, len(s)) + j := false + k := false + num := len(s) - 1 + for i := 0; i <= num; i++ { + d := s[i] + if k == false && d >= 'A' && d <= 'Z' { + k = true + } + if d >= 'a' && d <= 'z' && (j || k == false) { + d = d - 32 + j = false + k = true + } + if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' { + j = true + continue + } + data = append(data, d) + } + return Bytes2Str(data[:]) +} + +// SnakeString converts the string 's' to a snake string +func SnakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + for _, d := range Str2Bytes(s) { + if d >= 'A' && d <= 'Z' { + if j { + data = append(data, '_') + j = false + } + } else if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(Bytes2Str(data)) +} diff --git a/util/tool_install.go b/util/tool_install.go new file mode 100644 index 0000000..1081d4f --- /dev/null +++ b/util/tool_install.go @@ -0,0 +1,151 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/cloudwego/hertz/cmd/hz/meta" + "github.com/cloudwego/hertz/cmd/hz/util/logs" + gv "github.com/hashicorp/go-version" +) + +const ThriftgoMiniVersion = "v0.2.0" + +// QueryVersion will query the version of the corresponding executable. +func QueryVersion(exe string) (version string, err error) { + var buf strings.Builder + cmd := &exec.Cmd{ + Path: exe, + Args: []string{ + exe, "--version", + }, + Stdin: os.Stdin, + Stdout: &buf, + Stderr: &buf, + } + err = cmd.Run() + if err == nil { + version = strings.Split(buf.String(), " ")[1] + if strings.HasSuffix(version, "\n") { + version = version[:len(version)-1] + } + } + return +} + +// ShouldUpdate will return "true" when current is lower than latest. +func ShouldUpdate(current, latest string) bool { + cv, err := gv.NewVersion(current) + if err != nil { + return false + } + lv, err := gv.NewVersion(latest) + if err != nil { + return false + } + + return cv.Compare(lv) < 0 +} + +// InstallAndCheckThriftgo will automatically install thriftgo and judge whether it is installed successfully. +func InstallAndCheckThriftgo() error { + exe, err := exec.LookPath("go") + if err != nil { + return fmt.Errorf("can not find tool 'go': %v", err) + } + var buf strings.Builder + cmd := &exec.Cmd{ + Path: exe, + Args: []string{ + exe, "install", "github.com/cloudwego/thriftgo@latest", + }, + Stdin: os.Stdin, + Stdout: &buf, + Stderr: &buf, + } + + done := make(chan error) + logs.Infof("installing thriftgo automatically") + go func() { + done <- cmd.Run() + }() + select { + case err = <-done: + if err != nil { + return fmt.Errorf("can not install thriftgo, err: %v. Please install it manual, and make sure the version of thriftgo is greater than v0.2.0", cmd.Stderr) + } + case <-time.After(time.Second * 30): + return fmt.Errorf("install thriftgo time out.Please install it manual, and make sure the version of thriftgo is greater than v0.2.0") + } + + exist, err := CheckCompiler(meta.TpCompilerThrift) + if err != nil { + return fmt.Errorf("check %s exist failed, err: %v", meta.TpCompilerThrift, err) + } + if !exist { + return fmt.Errorf("install thriftgo failed. Please install it manual, and make sure the version of thriftgo is greater than v0.2.0") + } + + return nil +} + +// CheckCompiler will check if the tool exists. +func CheckCompiler(tool string) (bool, error) { + path, err := exec.LookPath(tool) + if err != nil { + goPath, err := GetGOPATH() + if err != nil { + return false, fmt.Errorf("get 'GOPATH' failed for find %s : %v", tool, path) + } + path = filepath.Join(goPath, "bin", tool) + } + + isExist, err := PathExist(path) + if err != nil { + return false, fmt.Errorf("can not check %s exist, err: %v", tool, err) + } + if !isExist { + return false, nil + } + + return true, nil +} + +// CheckAndUpdateThriftgo checks the version of thriftgo and updates the tool to the latest version if its version is less than v0.2.0. +func CheckAndUpdateThriftgo() error { + path, err := exec.LookPath(meta.TpCompilerThrift) + if err != nil { + return fmt.Errorf("can not find %s", meta.TpCompilerThrift) + } + curVersion, err := QueryVersion(path) + logs.Infof("current thriftgo version is %s", curVersion) + if ShouldUpdate(curVersion, ThriftgoMiniVersion) { + logs.Infof(" current thriftgo version is less than v0.2.0, so update thriftgo version") + err = InstallAndCheckThriftgo() + if err != nil { + return fmt.Errorf("update thriftgo version failed, err: %v", err) + } + } + + return nil +} diff --git a/util/tool_install_test.go b/util/tool_install_test.go new file mode 100644 index 0000000..946befd --- /dev/null +++ b/util/tool_install_test.go @@ -0,0 +1,36 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import "testing" + +func TestQueryVersion(t *testing.T) { + lowVersion := "v0.1.0" + equalVersion := "v0.2.0" + highVersion := "v0.3.0" + + if ShouldUpdate(lowVersion, ThriftgoMiniVersion) { + } + + if ShouldUpdate(equalVersion, ThriftgoMiniVersion) { + t.Fatal("should not be updated") + } + + if ShouldUpdate(highVersion, ThriftgoMiniVersion) { + t.Fatal("should not be updated") + } +}