gr_hz/protobuf/plugin.go
huangqz a407e10ebe 修复批量 update 多个 idl 时只生成最后一个 handler/router 的问题
原 protobuf 插件只取 FileToGenerate 的最后一个 idl 生成 http package,导致 hz_gen
批量脚本一次传入多个 --idl 时,除最后一个外其余 proto 只生成 model、不生成 handler/router。

- protobuf/plugin.go: Handle() 改为遍历全部 FileToGenerate;每个 idl 生成后立即写盘,
  使后续 idl 的 register.go/middleware.go/handler 合并能读到累积内容,保证多个新服务
  同批生成不互相覆盖注册。
- util/data.go: 新增 ResetUniqueNameSets()
- generator/router.go: 新增 ResetRouterState()
  每个 idl 生成前重置进程级命名状态,使单进程批量输出与逐文件生成逐字节一致,
  避免中间件分组变量名带全局序号后缀(_v1->_v184)造成大面积无谓 diff。
2026-06-16 18:26:57 +08:00

654 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright 2022 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Copyright (c) 2018 The Go Authors. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* This file may have been modified by CloudWeGo authors. All CloudWeGo
* Modifications are Copyright 2022 CloudWeGo Authors.
*/
package protobuf
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"golib.gaore.com/GaoreGo/gr_hz/config"
"golib.gaore.com/GaoreGo/gr_hz/generator"
"golib.gaore.com/GaoreGo/gr_hz/generator/model"
"golib.gaore.com/GaoreGo/gr_hz/meta"
"golib.gaore.com/GaoreGo/gr_hz/util"
"golib.gaore.com/GaoreGo/gr_hz/util/logs"
gengo "google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/runtime/protoimpl"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/pluginpb"
)
type Plugin struct {
*protogen.Plugin
Package string
Recursive bool
OutDir string
ModelDir string
UseDir string
IdlClientDir string
RmTags RemoveTags
PkgMap map[string]string
logger *logs.StdLogger
}
type RemoveTags []string
func (rm *RemoveTags) Exist(tag string) bool {
for _, rmTag := range *rm {
if rmTag == tag {
return true
}
}
return false
}
func (plugin *Plugin) Run() int {
plugin.setLogger()
args := &config.Argument{}
defer func() {
if args == nil {
return
}
if args.Verbose {
verboseLog := plugin.recvVerboseLogger()
if len(verboseLog) != 0 {
fmt.Fprintf(os.Stderr, verboseLog)
}
} else {
warning := plugin.recvWarningLogger()
if len(warning) != 0 {
fmt.Fprintf(os.Stderr, warning)
}
}
}()
// read protoc request
in, err := ioutil.ReadAll(os.Stdin)
if err != nil {
logs.Errorf("read request failed: %s\n", err.Error())
return meta.PluginError
}
req := &pluginpb.CodeGeneratorRequest{}
err = proto.Unmarshal(in, req)
if err != nil {
logs.Errorf("unmarshal request failed: %s\n", err.Error())
return meta.PluginError
}
args, err = plugin.parseArgs(*req.Parameter)
if err != nil {
logs.Errorf("parse args failed: %s\n", err.Error())
return meta.PluginError
}
CheckTagOption(args)
// generate
err = plugin.Handle(req, args)
if err != nil {
logs.Errorf("generate failed: %s\n", err.Error())
return meta.PluginError
}
return 0
}
func (plugin *Plugin) setLogger() {
plugin.logger = logs.NewStdLogger(logs.LevelInfo)
plugin.logger.Defer = true
plugin.logger.ErrOnly = true
logs.SetLogger(plugin.logger)
}
func (plugin *Plugin) recvWarningLogger() string {
warns := plugin.logger.Warn()
plugin.logger.Flush()
logs.SetLogger(logs.NewStdLogger(logs.LevelInfo))
return warns
}
func (plugin *Plugin) recvVerboseLogger() string {
info := plugin.logger.Out()
warns := plugin.logger.Warn()
verboseLog := string(info) + warns
plugin.logger.Flush()
logs.SetLogger(logs.NewStdLogger(logs.LevelInfo))
return verboseLog
}
func (plugin *Plugin) parseArgs(param string) (*config.Argument, error) {
args := new(config.Argument)
params := strings.Split(param, ",")
err := args.Unpack(params)
if err != nil {
return nil, err
}
plugin.Package, err = args.GetGoPackage()
if err != nil {
return nil, err
}
plugin.Recursive = !args.NoRecurse
plugin.ModelDir, err = args.GetModelDir()
if err != nil {
return nil, err
}
plugin.OutDir = args.OutDir
plugin.PkgMap = args.OptPkgMap
plugin.UseDir = args.Use
return args, nil
}
func (plugin *Plugin) Response(resp *pluginpb.CodeGeneratorResponse) error {
out, err := proto.Marshal(resp)
if err != nil {
return fmt.Errorf("marshal response failed: %s", err.Error())
}
_, err = os.Stdout.Write(out)
if err != nil {
return fmt.Errorf("write response failed: %s", err.Error())
}
return nil
}
func (plugin *Plugin) Handle(req *pluginpb.CodeGeneratorRequest, args *config.Argument) error {
plugin.fixGoPackage(req, plugin.PkgMap)
// new plugin
opts := protogen.Options{}
gen, err := opts.New(req)
plugin.Plugin = gen
plugin.RmTags = args.RmTags
if err != nil {
return fmt.Errorf("new protoc plugin failed: %s", err.Error())
}
// plugin start working
err = plugin.GenerateFiles(gen)
if err != nil {
// Error within the plugin will be responded by the plugin.
// But if the plugin does not response correctly, the error is returned to the upper level.
err := fmt.Errorf("generate model file failed: %s", err.Error())
gen.Error(err)
resp := gen.Response()
err2 := plugin.Response(resp)
if err2 != nil {
return err
}
return nil
}
if args.CmdType == meta.CmdModel {
resp := gen.Response()
// plugin stop working
err = plugin.Response(resp)
if err != nil {
return fmt.Errorf("write response failed: %s", err.Error())
}
return nil
}
files := gen.Request.ProtoFile
maps := make(map[string]*descriptorpb.FileDescriptorProto, len(files))
for _, file := range files {
maps[file.GetName()] = file
}
// 为 FileToGenerate 中的「每一个」idl 生成 http packagehandler/router/register/middleware
// 原实现只取最后一个 idl导致一次传入多个 --idl 时,除最后一个外其余 proto 都只生成
// model、不生成 handler/routerhz_gen 批量脚本正是踩了这个坑)。
//
// 关键点handler 追加、register.go、middleware.go 都是「读磁盘现有内容再合并」。插件进程内
// 默认不落盘(最终交给 protoc 写),因此若只在内存里循环,后一个 idl 读到的仍是旧文件,会丢掉
// 前一个 idl 新增的路由/中间件注册。所以每生成完一个 idl 就「立即写盘」,保证下一轮合并读到的是
// 累积后的最新内容。这样多个新服务同批生成也不会互相覆盖注册。
for _, idlName := range gen.Request.FileToGenerate {
// 每个 idl 开始前重置进程级命名状态,使其行为与「逐文件单进程」一致,
// 避免中间件分组变量名带上跨 idl 的全局序号后缀(如 _v184造成大面积无谓 diff。
util.ResetUniqueNameSets()
generator.ResetRouterState()
main := maps[idlName]
deps := make(map[string]*descriptorpb.FileDescriptorProto, len(main.GetDependency()))
for _, dep := range main.GetDependency() {
f, ok := maps[dep]
if !ok {
err := fmt.Errorf("dependency file not found: %s", dep)
gen.Error(err)
resp := gen.Response()
if err2 := plugin.Response(resp); err2 != nil {
return err
}
return nil
}
deps[dep] = f
}
pkgFiles, err := plugin.genHttpPackage(main, deps, args)
if err != nil {
err := fmt.Errorf("generate package files failed: %s", err.Error())
gen.Error(err)
resp := gen.Response()
if err2 := plugin.Response(resp); err2 != nil {
return err
}
return nil
}
// 立即落盘,使后续 idl 的 register.go/middleware.go/handler 合并读到最新内容。
for _, pkgFile := range pkgFiles {
absPath := pkgFile.Path
if !filepath.IsAbs(absPath) {
absPath = filepath.Join(args.OutDir, pkgFile.Path)
}
if err := os.MkdirAll(filepath.Dir(absPath), 0o755); err != nil {
return fmt.Errorf("create dir for '%s' failed: %s", absPath, err.Error())
}
if err := ioutil.WriteFile(absPath, []byte(pkgFile.Content), 0o644); err != nil {
return fmt.Errorf("write http package file '%s' failed: %s", absPath, err.Error())
}
}
}
// http package 已直接写盘;响应里只回模型文件(交给 protoc 写)。
resp := gen.Response()
if err := plugin.Response(resp); err != nil {
return fmt.Errorf("write response failed: %s", err.Error())
}
return nil
}
// fixGoPackage will update go_package to store all the model files in ${model_dir}
func (plugin *Plugin) fixGoPackage(req *pluginpb.CodeGeneratorRequest, pkgMap map[string]string) {
gopkg := plugin.Package
for _, f := range req.ProtoFile {
if strings.HasPrefix(f.GetPackage(), "google.protobuf") {
continue
}
opt := getGoPackage(f, pkgMap)
if !strings.Contains(opt, gopkg) {
if strings.HasPrefix(opt, "/") {
opt = gopkg + opt
} else {
opt = gopkg + "/" + opt
}
}
impt, _ := plugin.fixModelPathAndPackage(opt)
*f.Options.GoPackage = impt
}
}
// fixModelPathAndPackage will modify the go_package to adapt the go_package of the hz,
// for example adding the go module and model dir.
func (plugin *Plugin) fixModelPathAndPackage(pkg string) (impt, path string) {
if strings.HasPrefix(pkg, plugin.Package) {
impt = util.ImportToPathAndConcat(pkg[len(plugin.Package):], "")
}
if plugin.ModelDir != "" && plugin.ModelDir != "." {
modelImpt := util.PathToImport(string(filepath.Separator)+plugin.ModelDir, "")
// trim model dir for go package
if strings.HasPrefix(impt, modelImpt) {
impt = impt[len(modelImpt):]
}
impt = util.PathToImport(plugin.ModelDir, "") + impt
}
path = util.ImportToPath(impt, "")
impt = plugin.Package + "/" + impt
if util.IsWindows() {
impt = util.PathToImport(impt, "")
}
return
}
func (plugin *Plugin) GenerateFiles(pluginPb *protogen.Plugin) error {
idl := pluginPb.Request.FileToGenerate[len(pluginPb.Request.FileToGenerate)-1]
pluginPb.SupportedFeatures = gengo.SupportedFeatures
for _, f := range pluginPb.Files {
if f.Proto.GetName() == idl {
err := plugin.GenerateFile(pluginPb, f)
if err != nil {
return err
}
impt := string(f.GoImportPath)
if strings.HasPrefix(impt, plugin.Package) {
impt = impt[len(plugin.Package):]
}
plugin.IdlClientDir = impt
} else if plugin.Recursive {
if strings.HasPrefix(f.Proto.GetPackage(), "google.protobuf") {
continue
}
err := plugin.GenerateFile(pluginPb, f)
if err != nil {
return err
}
}
}
return nil
}
func (plugin *Plugin) GenerateFile(gen *protogen.Plugin, f *protogen.File) error {
impt := string(f.GoImportPath)
if strings.HasPrefix(impt, plugin.Package) {
impt = impt[len(plugin.Package):]
}
f.GeneratedFilenamePrefix = filepath.Join(util.ImportToPath(impt, ""), util.BaseName(f.Proto.GetName(), ".proto"))
f.Generate = true
// if use third-party model, no model code is generated within the project
if len(plugin.UseDir) != 0 {
return nil
}
file, err := generateFile(gen, f, plugin.RmTags)
if err != nil || file == nil {
return fmt.Errorf("generate file %s failed: %s", f.Proto.GetName(), err.Error())
}
return nil
}
// generateFile generates the contents of a .pb.go file.
func generateFile(gen *protogen.Plugin, file *protogen.File, rmTags RemoveTags) (*protogen.GeneratedFile, error) {
filename := file.GeneratedFilenamePrefix + ".pb.go"
g := gen.NewGeneratedFile(filename, file.GoImportPath)
f := newFileInfo(file)
genStandaloneComments(g, f, int32(FileDescriptorProto_Syntax_field_number))
genGeneratedHeader(gen, g, f)
genStandaloneComments(g, f, int32(FileDescriptorProto_Package_field_number))
packageDoc := genPackageKnownComment(f)
g.P(packageDoc, "package ", f.GoPackageName)
g.P()
// Emit a static check that enforces a minimum version of the proto package.
if gengo.GenerateVersionMarkers {
g.P("const (")
g.P("// Verify that this generated code is sufficiently up-to-date.")
g.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimpl.GenVersion, " - ", protoimplPackage.Ident("MinVersion"), ")")
g.P("// Verify that runtime/protoimpl is sufficiently up-to-date.")
g.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimplPackage.Ident("MaxVersion"), " - ", protoimpl.GenVersion, ")")
g.P(")")
g.P()
}
for i, imps := 0, f.Desc.Imports(); i < imps.Len(); i++ {
genImport(gen, g, f, imps.Get(i))
}
for _, enum := range f.allEnums {
genEnum(g, f, enum)
}
var err error
for _, message := range f.allMessages {
err = genMessage(g, f, message, rmTags)
if err != nil {
return nil, err
}
}
genExtensions(g, f)
genReflectFileDescriptor(gen, g, f)
return g, nil
}
func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, rmTags RemoveTags) error {
if m.Desc.IsMapEntry() {
return nil
}
// Message type declaration.
g.Annotate(m.GoIdent.GoName, m.Location)
leadingComments := appendDeprecationSuffix(m.Comments.Leading,
m.Desc.Options().(*descriptorpb.MessageOptions).GetDeprecated())
g.P(leadingComments,
"type ", m.GoIdent, " struct {")
err := genMessageFields(g, f, m, rmTags)
if err != nil {
return err
}
g.P("}")
g.P()
genMessageKnownFunctions(g, f, m)
genMessageDefaultDecls(g, f, m)
genMessageMethods(g, f, m)
genMessageOneofWrapperTypes(g, f, m)
return nil
}
func genMessageFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, rmTags RemoveTags) error {
sf := f.allMessageFieldsByPtr[m]
genMessageInternalFields(g, f, m, sf)
var err error
for _, field := range m.Fields {
err = genMessageField(g, f, m, field, sf, rmTags)
if err != nil {
return err
}
}
return nil
}
func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field, sf *structFields, rmTags RemoveTags) error {
if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() {
// It would be a bit simpler to iterate over the oneofs below,
// but generating the field here keeps the contents of the Go
// struct in the same order as the contents of the source
// .proto file.
if oneof.Fields[0] != field {
return nil // only generate for first appearance
}
tags := structTags{
{"protobuf_oneof", string(oneof.Desc.Name())},
}
if m.isTracked {
tags = append(tags, gotrackTags...)
}
g.Annotate(m.GoIdent.GoName+"."+oneof.GoName, oneof.Location)
leadingComments := oneof.Comments.Leading
if leadingComments != "" {
leadingComments += "\n"
}
ss := []string{fmt.Sprintf(" Types that are assignable to %s:\n", oneof.GoName)}
for _, field := range oneof.Fields {
ss = append(ss, "\t*"+field.GoIdent.GoName+"\n")
}
leadingComments += protogen.Comments(strings.Join(ss, ""))
g.P(leadingComments,
oneof.GoName, " ", oneofInterfaceName(oneof), tags)
sf.append(oneof.GoName)
return nil
}
goType, pointer := fieldGoType(g, f, field)
if pointer {
goType = "*" + goType
}
tags := structTags{
{"protobuf", fieldProtobufTagValue(field)},
//{"json", fieldJSONTagValue(field)},
}
if field.Desc.IsMap() {
key := field.Message.Fields[0]
val := field.Message.Fields[1]
tags = append(tags, structTags{
{"protobuf_key", fieldProtobufTagValue(key)},
{"protobuf_val", fieldProtobufTagValue(val)},
}...)
}
err := injectTagsToStructTags(field.Desc, &tags, true, rmTags)
if err != nil {
return err
}
if m.isTracked {
tags = append(tags, gotrackTags...)
}
name := field.GoName
if field.Desc.IsWeak() {
name = WeakFieldPrefix_goname + name
}
g.Annotate(m.GoIdent.GoName+"."+name, field.Location)
leadingComments := appendDeprecationSuffix(field.Comments.Leading,
field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
g.P(leadingComments,
name, " ", goType, tags,
trailingComment(field.Comments.Trailing))
sf.append(field.GoName)
return nil
}
func (plugin *Plugin) getIdlInfo(ast *descriptorpb.FileDescriptorProto, deps map[string]*descriptorpb.FileDescriptorProto, args *config.Argument) (*generator.HttpPackage, error) {
if ast == nil {
return nil, fmt.Errorf("ast is nil")
}
pkg := getGoPackage(ast, map[string]string{})
main := &model.Model{
FilePath: ast.GetName(),
Package: pkg,
PackageName: util.BaseName(pkg, ""),
}
fileInfo := FileInfos{
Official: deps,
PbReflect: nil,
}
rs, err := NewResolver(ast, fileInfo, main, map[string]string{})
if err != nil {
return nil, fmt.Errorf("new protobuf resolver failed, err:%v", err)
}
err = rs.LoadAll(ast)
if err != nil {
return nil, err
}
services, err := astToService(ast, rs, args.CmdType, plugin.Plugin)
if err != nil {
return nil, err
}
var models model.Models
for _, s := range services {
models.MergeArray(s.Models)
}
return &generator.HttpPackage{
Services: services,
IdlName: ast.GetName(),
Package: util.BaseName(pkg, ""),
Models: models,
}, nil
}
func (plugin *Plugin) genHttpPackage(ast *descriptorpb.FileDescriptorProto, deps map[string]*descriptorpb.FileDescriptorProto, args *config.Argument) ([]generator.File, error) {
options := CheckTagOption(args)
idl, err := plugin.getIdlInfo(ast, deps, args)
if err != nil {
return nil, err
}
customPackageTemplate := args.CustomizePackage
pkg, err := args.GetGoPackage()
if err != nil {
return nil, err
}
handlerDir, err := args.GetHandlerDir()
if err != nil {
return nil, err
}
routerDir, err := args.GetRouterDir()
if err != nil {
return nil, err
}
modelDir, err := args.GetModelDir()
if err != nil {
return nil, err
}
clientDir, err := args.GetClientDir()
if err != nil {
return nil, err
}
sg := generator.HttpPackageGenerator{
ConfigPath: customPackageTemplate,
HandlerDir: handlerDir,
RouterDir: routerDir,
ModelDir: modelDir,
UseDir: args.Use,
ClientDir: clientDir,
TemplateGenerator: generator.TemplateGenerator{
OutputDir: args.OutDir,
Excludes: args.Excludes,
},
ProjPackage: pkg,
Options: options,
HandlerByMethod: args.HandlerByMethod,
CmdType: args.CmdType,
IdlClientDir: plugin.IdlClientDir,
ForceClientDir: args.ForceClientDir,
BaseDomain: args.BaseDomain,
SnakeStyleMiddleware: args.SnakeStyleMiddleware,
}
if args.ModelBackend != "" {
sg.Backend = meta.Backend(args.ModelBackend)
}
generator.SetDefaultTemplateConfig()
err = sg.Generate(idl)
if err != nil {
return nil, fmt.Errorf("generate http package error: %v", err)
}
files, err := sg.GetFormatAndExcludedFiles()
if err != nil {
return nil, fmt.Errorf("persist http package error: %v", err)
}
return files, nil
}