446 lines
11 KiB
Go
446 lines
11 KiB
Go
/*
|
|
* Copyright 2022 CloudWeGo Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package thrift
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/cloudwego/thriftgo/generator/backend"
|
|
"github.com/cloudwego/thriftgo/generator/golang"
|
|
"github.com/cloudwego/thriftgo/generator/golang/styles"
|
|
"github.com/cloudwego/thriftgo/parser"
|
|
thriftgo_plugin "github.com/cloudwego/thriftgo/plugin"
|
|
"std.gaore.com/Gaore-Go/gr_hz/config"
|
|
"std.gaore.com/Gaore-Go/gr_hz/generator"
|
|
"std.gaore.com/Gaore-Go/gr_hz/generator/model"
|
|
"std.gaore.com/Gaore-Go/gr_hz/meta"
|
|
"std.gaore.com/Gaore-Go/gr_hz/util"
|
|
"std.gaore.com/Gaore-Go/gr_hz/util/logs"
|
|
)
|
|
|
|
type Plugin struct {
|
|
req *thriftgo_plugin.Request
|
|
args *config.Argument
|
|
logger *logs.StdLogger
|
|
rmTags []string
|
|
}
|
|
|
|
func (plugin *Plugin) Run() int {
|
|
plugin.setLogger()
|
|
args := &config.Argument{}
|
|
defer func() {
|
|
if args == nil {
|
|
return
|
|
}
|
|
if args.Verbose {
|
|
verboseLog := plugin.recvVerboseLogger()
|
|
if len(verboseLog) != 0 {
|
|
fmt.Fprintf(os.Stderr, verboseLog)
|
|
}
|
|
} else {
|
|
warning := plugin.recvWarningLogger()
|
|
if len(warning) != 0 {
|
|
fmt.Fprintf(os.Stderr, warning)
|
|
}
|
|
}
|
|
}()
|
|
|
|
err := plugin.handleRequest()
|
|
if err != nil {
|
|
logs.Errorf("handle request failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
|
|
args, err = plugin.parseArgs()
|
|
if err != nil {
|
|
logs.Errorf("parse args failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
plugin.rmTags = args.RmTags
|
|
if args.CmdType == meta.CmdModel {
|
|
// check tag options for model mode
|
|
CheckTagOption(plugin.args)
|
|
res, err := plugin.GetResponse(nil, args.OutDir)
|
|
if err != nil {
|
|
logs.Errorf("get response failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
plugin.response(res)
|
|
if err != nil {
|
|
logs.Errorf("response failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
return 0
|
|
}
|
|
|
|
err = plugin.initNameStyle()
|
|
if err != nil {
|
|
logs.Errorf("init naming style failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
|
|
options := CheckTagOption(plugin.args)
|
|
|
|
pkgInfo, err := plugin.getPackageInfo()
|
|
if err != nil {
|
|
logs.Errorf("get http package info failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
|
|
customPackageTemplate := args.CustomizePackage
|
|
pkg, err := args.GetGoPackage()
|
|
if err != nil {
|
|
logs.Errorf("get go package failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
handlerDir, err := args.GetHandlerDir()
|
|
if err != nil {
|
|
logs.Errorf("get handler dir failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
routerDir, err := args.GetRouterDir()
|
|
if err != nil {
|
|
logs.Errorf("get router dir failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
modelDir, err := args.GetModelDir()
|
|
if err != nil {
|
|
logs.Errorf("get model dir failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
clientDir, err := args.GetClientDir()
|
|
if err != nil {
|
|
logs.Errorf("get client dir failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
sg := generator.HttpPackageGenerator{
|
|
ConfigPath: customPackageTemplate,
|
|
HandlerDir: handlerDir,
|
|
RouterDir: routerDir,
|
|
ModelDir: modelDir,
|
|
UseDir: args.Use,
|
|
ClientDir: clientDir,
|
|
TemplateGenerator: generator.TemplateGenerator{
|
|
OutputDir: args.OutDir,
|
|
Excludes: args.Excludes,
|
|
},
|
|
ProjPackage: pkg,
|
|
Options: options,
|
|
HandlerByMethod: args.HandlerByMethod,
|
|
CmdType: args.CmdType,
|
|
IdlClientDir: util.SubDir(modelDir, pkgInfo.Package),
|
|
ForceClientDir: args.ForceClientDir,
|
|
BaseDomain: args.BaseDomain,
|
|
SnakeStyleMiddleware: args.SnakeStyleMiddleware,
|
|
}
|
|
if args.ModelBackend != "" {
|
|
sg.Backend = meta.Backend(args.ModelBackend)
|
|
}
|
|
generator.SetDefaultTemplateConfig()
|
|
|
|
err = sg.Generate(pkgInfo)
|
|
if err != nil {
|
|
logs.Errorf("generate package failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
if len(args.Use) != 0 {
|
|
err = sg.Persist()
|
|
if err != nil {
|
|
logs.Errorf("persist file failed within '-use' option: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
res := thriftgo_plugin.BuildErrorResponse(errors.New(meta.TheUseOptionMessage).Error())
|
|
err = plugin.response(res)
|
|
if err != nil {
|
|
logs.Errorf("response failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
return 0
|
|
}
|
|
files, err := sg.GetFormatAndExcludedFiles()
|
|
if err != nil {
|
|
logs.Errorf("format file failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
res, err := plugin.GetResponse(files, sg.OutputDir)
|
|
if err != nil {
|
|
logs.Errorf("get response failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
err = plugin.response(res)
|
|
if err != nil {
|
|
logs.Errorf("response failed: %s", err.Error())
|
|
return meta.PluginError
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func (plugin *Plugin) setLogger() {
|
|
plugin.logger = logs.NewStdLogger(logs.LevelInfo)
|
|
plugin.logger.Defer = true
|
|
plugin.logger.ErrOnly = true
|
|
logs.SetLogger(plugin.logger)
|
|
}
|
|
|
|
func (plugin *Plugin) recvWarningLogger() string {
|
|
warns := plugin.logger.Warn()
|
|
plugin.logger.Flush()
|
|
logs.SetLogger(logs.NewStdLogger(logs.LevelInfo))
|
|
return warns
|
|
}
|
|
|
|
func (plugin *Plugin) recvVerboseLogger() string {
|
|
info := plugin.logger.Out()
|
|
warns := plugin.logger.Warn()
|
|
verboseLog := string(info) + warns
|
|
plugin.logger.Flush()
|
|
logs.SetLogger(logs.NewStdLogger(logs.LevelInfo))
|
|
return verboseLog
|
|
}
|
|
|
|
func (plugin *Plugin) handleRequest() error {
|
|
data, err := ioutil.ReadAll(os.Stdin)
|
|
if err != nil {
|
|
return fmt.Errorf("read request failed: %s", err.Error())
|
|
}
|
|
req, err := thriftgo_plugin.UnmarshalRequest(data)
|
|
if err != nil {
|
|
return fmt.Errorf("unmarshal request failed: %s", err.Error())
|
|
}
|
|
plugin.req = req
|
|
// init thriftgo utils
|
|
thriftgoUtil = golang.NewCodeUtils(backend.DummyLogFunc())
|
|
thriftgoUtil.HandleOptions(req.GeneratorParameters)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (plugin *Plugin) parseArgs() (*config.Argument, error) {
|
|
if plugin.req == nil {
|
|
return nil, fmt.Errorf("request is nil")
|
|
}
|
|
args := new(config.Argument)
|
|
err := args.Unpack(plugin.req.PluginParameters)
|
|
if err != nil {
|
|
logs.Errorf("unpack args failed: %s", err.Error())
|
|
}
|
|
plugin.args = args
|
|
return args, nil
|
|
}
|
|
|
|
// initNameStyle initializes the naming style based on the "naming_style" option for thrift.
|
|
func (plugin *Plugin) initNameStyle() error {
|
|
if len(plugin.args.ThriftOptions) == 0 {
|
|
return nil
|
|
}
|
|
for _, opt := range plugin.args.ThriftOptions {
|
|
parts := strings.SplitN(opt, "=", 2)
|
|
if len(parts) == 2 && parts[0] == "naming_style" {
|
|
NameStyle = styles.NewNamingStyle(parts[1])
|
|
if NameStyle == nil {
|
|
return fmt.Errorf(fmt.Sprintf("do not support \"%s\" naming style", parts[1]))
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (plugin *Plugin) getPackageInfo() (*generator.HttpPackage, error) {
|
|
req := plugin.req
|
|
args := plugin.args
|
|
|
|
ast := req.GetAST()
|
|
if ast == nil {
|
|
return nil, fmt.Errorf("no ast")
|
|
}
|
|
logs.Infof("Processing %s", ast.GetFilename())
|
|
|
|
pkgMap := args.OptPkgMap
|
|
pkg := getGoPackage(ast, pkgMap)
|
|
main := &model.Model{
|
|
FilePath: ast.Filename,
|
|
Package: pkg,
|
|
PackageName: util.SplitPackageName(pkg, ""),
|
|
}
|
|
rs, err := NewResolver(ast, main, pkgMap)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("new thrift resolver failed, err:%v", err)
|
|
}
|
|
err = rs.LoadAll(ast)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
idlPackage := getGoPackage(ast, pkgMap)
|
|
if idlPackage == "" {
|
|
return nil, fmt.Errorf("go package for '%s' is not defined", ast.GetFilename())
|
|
}
|
|
|
|
services, err := astToService(ast, rs, args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var models model.Models
|
|
for _, s := range services {
|
|
models.MergeArray(s.Models)
|
|
}
|
|
|
|
return &generator.HttpPackage{
|
|
Services: services,
|
|
IdlName: ast.GetFilename(),
|
|
Package: idlPackage,
|
|
Models: models,
|
|
}, nil
|
|
}
|
|
|
|
func (plugin *Plugin) response(res *thriftgo_plugin.Response) error {
|
|
data, err := thriftgo_plugin.MarshalResponse(res)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal response failed: %s", err.Error())
|
|
}
|
|
_, err = os.Stdout.Write(data)
|
|
if err != nil {
|
|
return fmt.Errorf("write response failed: %s", err.Error())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (plugin *Plugin) InsertTag() ([]*thriftgo_plugin.Generated, error) {
|
|
var res []*thriftgo_plugin.Generated
|
|
|
|
if plugin.args.NoRecurse {
|
|
outPath := plugin.req.OutputPath
|
|
packageName := getGoPackage(plugin.req.AST, nil)
|
|
fileName := util.BaseNameAndTrim(plugin.req.AST.GetFilename()) + ".go"
|
|
outPath = filepath.Join(outPath, packageName, fileName)
|
|
for _, st := range plugin.req.AST.Structs {
|
|
stName := st.GetName()
|
|
for _, f := range st.Fields {
|
|
fieldName := f.GetName()
|
|
tagString, err := getTagString(f, plugin.rmTags)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
insertPointer := "struct." + stName + "." + fieldName + "." + "tag"
|
|
gen := &thriftgo_plugin.Generated{
|
|
Content: tagString,
|
|
Name: &outPath,
|
|
InsertionPoint: &insertPointer,
|
|
}
|
|
res = append(res, gen)
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
for ast := range plugin.req.AST.DepthFirstSearch() {
|
|
outPath := plugin.req.OutputPath
|
|
packageName := getGoPackage(ast, nil)
|
|
fileName := util.BaseNameAndTrim(ast.GetFilename()) + ".go"
|
|
outPath = filepath.Join(outPath, packageName, fileName)
|
|
|
|
for _, st := range ast.Structs {
|
|
stName := st.GetName()
|
|
for _, f := range st.Fields {
|
|
fieldName := f.GetName()
|
|
tagString, err := getTagString(f, plugin.rmTags)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
insertPointer := "struct." + stName + "." + fieldName + "." + "tag"
|
|
gen := &thriftgo_plugin.Generated{
|
|
Content: tagString,
|
|
Name: &outPath,
|
|
InsertionPoint: &insertPointer,
|
|
}
|
|
res = append(res, gen)
|
|
}
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func (plugin *Plugin) GetResponse(files []generator.File, outputDir string) (*thriftgo_plugin.Response, error) {
|
|
var contents []*thriftgo_plugin.Generated
|
|
for _, file := range files {
|
|
filePath := filepath.Join(outputDir, file.Path)
|
|
content := &thriftgo_plugin.Generated{
|
|
Content: file.Content,
|
|
Name: &filePath,
|
|
}
|
|
contents = append(contents, content)
|
|
}
|
|
|
|
insertTag, err := plugin.InsertTag()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
contents = append(contents, insertTag...)
|
|
|
|
return &thriftgo_plugin.Response{
|
|
Contents: contents,
|
|
}, nil
|
|
}
|
|
|
|
func getTagString(f *parser.Field, rmTags []string) (string, error) {
|
|
field := model.Field{}
|
|
err := injectTags(f, &field, true, false)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
disableTag := false
|
|
if v := getAnnotation(f.Annotations, AnnotationNone); len(v) > 0 {
|
|
if strings.EqualFold(v[0], "true") {
|
|
disableTag = true
|
|
}
|
|
}
|
|
|
|
for _, rmTag := range rmTags {
|
|
for _, t := range field.Tags {
|
|
if t.IsDefault && strings.EqualFold(t.Key, rmTag) {
|
|
field.Tags.Remove(t.Key)
|
|
}
|
|
}
|
|
}
|
|
|
|
var tagString string
|
|
tags := field.Tags
|
|
for idx, tag := range tags {
|
|
value := tag.Value
|
|
if disableTag {
|
|
value = "-"
|
|
}
|
|
if idx == 0 {
|
|
tagString += " " + tag.Key + ":\"" + value + "\"" + " "
|
|
} else if idx == len(tags)-1 {
|
|
tagString += tag.Key + ":\"" + value + "\""
|
|
} else {
|
|
tagString += tag.Key + ":\"" + value + "\"" + " "
|
|
}
|
|
}
|
|
|
|
return tagString, nil
|
|
}
|