gr_hz/generator/handler.go

321 lines
9.2 KiB
Go
Raw Normal View History

2024-04-30 19:30:09 +08:00
/*
* Copyright 2022 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package generator
import (
"bytes"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
2024-04-30 20:22:44 +08:00
"gr_hz/generator/model"
"gr_hz/util"
"gr_hz/util/logs"
2024-04-30 19:30:09 +08:00
)
// HttpMethod describes a single HTTP endpoint and is the per-method data
// handed to the handler templates used in this file.
type HttpMethod struct {
	// Name is the handler function name. In per-method mode its snake-cased
	// form is also used as the generated file name.
	Name string
	// HTTPMethod is appended in brackets to the comment built by InitComment.
	HTTPMethod string
	// Comment is the handler's doc comment; InitComment rewrites it in place.
	Comment string

	// Request and return type descriptors. They are not read in this file;
	// presumably the templates consume them.
	RequestTypeName    string
	RequestTypePackage string
	RequestTypeRawName string
	ReturnTypeName     string
	ReturnTypePackage  string
	ReturnTypeRawName  string

	// Path is the route written after "@router" by InitComment.
	Path string
	// Serializer is normalized through MapSerializer by Handler.Format.
	Serializer string
	// OutputDir is the sub-directory (below the handler directory) that
	// receives the file in per-method mode.
	OutputDir string

	RefPackage      string // handler import dir
	RefPackageAlias string // handler import alias
	ModelPackage    map[string]string
	// GenHandler reports whether a handler should be generated for this
	// method; it prevents duplicates when one IDL interface corresponds to
	// multiple HTTP methods.
	GenHandler bool
	// Annotations map[string]string

	// Models holds the models referenced by the method; processHandler walks
	// it to collect the imports of the handler file.
	Models map[string]*model.Model
}
// Handler is the template data for one generated handler file.
type Handler struct {
	// FilePath is the destination of the generated file.
	FilePath string
	// PackageName is the Go package name of the generated file.
	PackageName string
	// ProjPackage is copied from the generator's ProjPackage.
	ProjPackage string
	// Imports maps an import alias to the model whose package must be
	// imported; it is (re)built by processHandler.
	Imports map[string]*model.Model
	// Methods lists the methods rendered into the file.
	Methods []*HttpMethod
}
// SingleHandler is the template data used when one additional handler
// function is appended to an already existing handler file. It exposes the
// embedded method's fields together with the file-level information of the
// handler the method belongs to.
type SingleHandler struct {
	*HttpMethod

	FilePath    string // path of the handler file being extended
	PackageName string // package name of that handler file
	ProjPackage string // project package of that handler
}
// Client is the template data for a generated client file: the handler data
// of a service plus the service's name.
type Client struct {
	Handler

	// ServiceName is the name of the service the client is generated for.
	ServiceName string
}
// genHandler generates the handler source, and when pkgGen.ClientDir is set
// also the client source, for every service of pkg.
//
// With pkgGen.HandlerByMethod each method gets its own file at
// handlerDir/<method.OutputDir>/<snake_case(method.Name)>.go. Otherwise all
// methods of a service share one file named after the service, placed in
// handlerDir or, when the service sets ServiceGenDir, in that directory.
// In both modes the methods are passed through processHandler, which also
// registers them on root.
//
// Generation stops at the first failure. The returned error wraps the
// underlying cause (%w), so callers may inspect it with errors.Is/errors.As;
// the message text is the same as when the cause was formatted with %v.
func (pkgGen *HttpPackageGenerator) genHandler(pkg *HttpPackage, handlerDir, handlerPackage string, root *RouterNode) error {
	for _, s := range pkg.Services {
		var handler Handler
		if pkgGen.HandlerByMethod { // generate handler by method
			for _, m := range s.Methods {
				filePath := filepath.Join(handlerDir, m.OutputDir, util.ToSnakeCase(m.Name)+".go")
				handler = Handler{
					FilePath:    filePath,
					PackageName: util.SplitPackage(filepath.Dir(filePath), ""),
					Methods:     []*HttpMethod{m},
					ProjPackage: pkgGen.ProjPackage,
				}

				if err := pkgGen.processHandler(&handler, root, handlerDir, m.OutputDir, true); err != nil {
					return fmt.Errorf("generate handler %s failed, err: %w", handler.FilePath, err)
				}

				if m.GenHandler {
					if err := pkgGen.updateHandler(handler, handlerTplName, handler.FilePath, false); err != nil {
						return fmt.Errorf("generate handler %s failed, err: %w", handler.FilePath, err)
					}
				}
			}
		} else { // generate handler service
			tmpHandlerDir := handlerDir
			tmpHandlerPackage := handlerPackage
			if len(s.ServiceGenDir) != 0 {
				tmpHandlerDir = s.ServiceGenDir
				tmpHandlerPackage = util.SubPackage(pkgGen.ProjPackage, tmpHandlerDir)
			}
			handler = Handler{
				FilePath:    filepath.Join(tmpHandlerDir, util.ToSnakeCase(s.Name)+".go"),
				PackageName: util.SplitPackage(tmpHandlerPackage, ""),
				Methods:     s.Methods,
				ProjPackage: pkgGen.ProjPackage,
			}

			// Every method of the service lives in the same handler package.
			for _, m := range s.Methods {
				m.RefPackage = tmpHandlerPackage
				m.RefPackageAlias = util.BaseName(tmpHandlerPackage, "")
			}

			if err := pkgGen.processHandler(&handler, root, "", "", false); err != nil {
				return fmt.Errorf("generate handler %s failed, err: %w", handler.FilePath, err)
			}

			// Avoid generating duplicate handlers when IDL interface corresponds to multiple http methods
			methods := handler.Methods
			handler.Methods = []*HttpMethod{}
			for _, m := range methods {
				if m.GenHandler {
					handler.Methods = append(handler.Methods, m)
				}
			}

			if err := pkgGen.updateHandler(handler, handlerTplName, handler.FilePath, false); err != nil {
				return fmt.Errorf("generate handler %s failed, err: %w", handler.FilePath, err)
			}
		}

		if len(pkgGen.ClientDir) != 0 {
			// NOTE(review): the client reuses handler as left by the branch
			// above. In per-method mode that is the handler of the last method
			// only (or the zero value for a service without methods); in
			// service mode Methods has been filtered by GenHandler. Confirm
			// this is what the client template expects.
			clientDir := util.SubDir(pkgGen.ClientDir, pkg.Package)
			clientPackage := util.SubPackage(pkgGen.ProjPackage, clientDir)
			client := Client{}
			client.Handler = handler
			client.ServiceName = s.Name
			client.PackageName = util.SplitPackage(clientPackage, "")
			client.FilePath = filepath.Join(clientDir, util.ToSnakeCase(s.Name)+".go")
			if err := pkgGen.updateClient(client, clientTplName, client.FilePath, false); err != nil {
				return fmt.Errorf("generate client %s failed, err: %w", client.FilePath, err)
			}
		}
	}
	return nil
}
// processHandler prepares handler for rendering:
//
//  1. rebuilds handler.Imports from the models of every method,
//  2. registers every method on root via root.Update,
//  3. when pkgGen.UseDir is set, rewrites the models' package paths from the
//     project's model directory to UseDir,
//  4. calls handler.Format.
//
// handlerDir and projectOutDir are only consulted when handlerByMethod is
// true; they determine the package passed to root.Update as the method's own
// handler package. The first error returned by root.Update aborts processing.
func (pkgGen *HttpPackageGenerator) processHandler(handler *Handler, root *RouterNode, handlerDir, projectOutDir string, handlerByMethod bool) error {
	var methodPackage string
	if handlerByMethod {
		methodPackage = util.SubPackage(pkgGen.ProjPackage, filepath.Join(handlerDir, projectOutDir))
	}

	handler.Imports = make(map[string]*model.Model, len(handler.Methods))
	for _, method := range handler.Methods {
		// Iterate over the request and return parameters of the method to get import path.
		for key, mdl := range method.Models {
			// Prefer the package name as alias; fall back to the model's map
			// key when that alias is already bound to a different package.
			alias := mdl.PackageName
			if bound, taken := handler.Imports[alias]; taken && bound.Package != mdl.Package {
				alias = key
			}
			handler.Imports[alias] = mdl
		}

		if err := root.Update(method, handler.PackageName, methodPackage); err != nil {
			return err
		}
	}

	if pkgGen.UseDir != "" {
		generatedModelDir := filepath.Clean(filepath.Join(pkgGen.ProjPackage, pkgGen.ModelDir))
		externalModelDir := filepath.Clean(pkgGen.UseDir)
		for _, method := range handler.Methods {
			for _, mdl := range method.Models {
				// Only the first occurrence is replaced.
				mdl.Package = strings.Replace(mdl.Package, generatedModelDir, externalModelDir, 1)
			}
		}
	}

	handler.Format()
	return nil
}
// updateHandler creates or extends the handler file at filePath using the
// template named handlerTpl.
//
//   - If the template is disabled, nothing is done.
//   - If filePath does not exist, the file is generated from the template
//     (noRepeat is forwarded to the template generator).
//   - If it exists and pkgGen.HandlerByMethod is set, it is left untouched.
//   - Otherwise the existing content is read, missing model imports and
//     template-defined imports are inserted, and a handler function rendered
//     from handlerSingleTplName is appended for every method for which the
//     file does not yet contain "func <Name>(". The result is queued in
//     pkgGen.files.
//
// On the update path handler must be a Handler value; any other dynamic type
// yields an error.
func (pkgGen *HttpPackageGenerator) updateHandler(handler interface{}, handlerTpl, filePath string, noRepeat bool) error {
	if pkgGen.tplsInfo[handlerTpl].Disable {
		return nil
	}

	isExist, err := util.PathExist(filePath)
	if err != nil {
		return err
	}
	if !isExist {
		return pkgGen.TemplateGenerator.Generate(handler, handlerTpl, filePath, noRepeat)
	}
	if pkgGen.HandlerByMethod { // method by handler, do not need to insert new content
		return nil
	}

	file, err := ioutil.ReadFile(filePath)
	if err != nil {
		return err
	}

	// Assert once and report a mismatch as an error instead of panicking.
	h, ok := handler.(Handler)
	if !ok {
		return fmt.Errorf("update handler %s failed, unsupported handler type %T", filePath, handler)
	}

	// insert new model imports
	for alias, mdl := range h.Imports {
		if bytes.Contains(file, []byte(mdl.Package)) {
			continue
		}
		file, err = util.AddImportForContent(file, alias, mdl.Package)
		if err != nil {
			return err
		}
	}

	// insert customized imports
	if tplInfo, exist := pkgGen.TemplateGenerator.tplsInfo[handlerTpl]; exist {
		if len(tplInfo.UpdateBehavior.ImportTpl) != 0 {
			imptSlice, err := getInsertImportContent(tplInfo, handler, file)
			if err != nil {
				return err
			}
			for _, impt := range imptSlice {
				if bytes.Contains(file, []byte(impt[1])) {
					continue
				}
				// A failure here is only a warning, so the current content
				// must survive it: adopt the helper's result on success only,
				// otherwise the rest of the update would operate on (and
				// finally write out) whatever the failed call returned.
				updated, err := util.AddImportForContent(file, impt[0], impt[1])
				if err != nil {
					logs.Warnf("can not add import(%s) for file(%s), err: %v\n", impt[1], filePath, err)
					continue
				}
				file = updated
			}
		}
	}

	// insert new handler
	for _, method := range h.Methods {
		if bytes.Contains(file, []byte(fmt.Sprintf("func %s(", method.Name))) {
			continue
		}

		// Generate additional handlers using templates
		handlerSingleTpl := pkgGen.tpls[handlerSingleTplName]
		if handlerSingleTpl == nil {
			return fmt.Errorf("tpl %s not found", handlerSingleTplName)
		}
		data := SingleHandler{
			HttpMethod:  method,
			FilePath:    h.FilePath,
			PackageName: h.PackageName,
			ProjPackage: h.ProjPackage,
		}
		var handlerFunc bytes.Buffer
		if err := handlerSingleTpl.Execute(&handlerFunc, data); err != nil {
			return fmt.Errorf("execute template \"%s\" failed, %v", handlerSingleTplName, err)
		}

		// Append the rendered function to the end of the file content.
		file = append(file, handlerFunc.Bytes()...)
	}

	pkgGen.files = append(pkgGen.files, File{filePath, string(file), false, ""})
	return nil
}
// updateClient generates the client file at filePath from the template named
// clientTpl, but only when the file does not exist yet. An existing client
// file is never modified; that case is logged and reported as success.
// An error from the existence check is returned unchanged.
func (pkgGen *HttpPackageGenerator) updateClient(client interface{}, clientTpl, filePath string, noRepeat bool) error {
	exists, err := util.PathExist(filePath)
	switch {
	case err != nil:
		return err
	case exists:
		logs.Infof("Client file:%s has been generated, so don't update it", filePath)
		return nil
	}
	return pkgGen.TemplateGenerator.Generate(client, clientTpl, filePath, noRepeat)
}
// InitComment rewrites m.Comment into the doc comment emitted above the
// generated handler:
//
//	// <Name> <text>
//	// <further lines>
//	// @router <Path> [<HTTPMethod>]
//
// The incoming comment may or may not carry "//" markers. The marker and the
// whitespace around it are stripped from the first line, and from any further
// line that starts with "//", before every line is prefixed again, so markers
// are never doubled. If the text already starts with m.Name the name is not
// repeated; an empty comment becomes "// <Name> .". The "@router" line is only
// added when the text does not already contain "@router ". The
// " [<HTTPMethod>]" suffix is always appended to the end of the result.
func (m *HttpMethod) InitComment() {
	// Trim again after removing the slashes: a conventional "// Name ..."
	// comment would otherwise keep a leading space, defeating the HasPrefix
	// check below and duplicating the name.
	text := strings.TrimSpace(strings.TrimLeft(strings.TrimSpace(m.Comment), "/"))
	switch {
	case text == "":
		text = "// " + m.Name + " ."
	case strings.HasPrefix(text, m.Name):
		text = "// " + text
	default:
		text = "// " + m.Name + " " + text
	}

	// Continuation lines that already carry a marker lose it here, because
	// the join below adds one to every line after the first. Lines without a
	// marker are kept verbatim.
	lines := strings.Split(text, "\n")
	for i := 1; i < len(lines); i++ {
		if trimmed := strings.TrimSpace(lines[i]); strings.HasPrefix(trimmed, "//") {
			lines[i] = strings.TrimSpace(strings.TrimLeft(trimmed, "/"))
		}
	}
	text = strings.Join(lines, "\n// ")

	if !strings.Contains(text, "@router ") {
		text += "\n// @router " + m.Path
	}
	m.Comment = text + " [" + m.HTTPMethod + "]"
}
// MapSerializer translates a serializer identifier into the name used by the
// templates: "thrift" becomes "Thrift", "pb" becomes "ProtoBuf", and every
// other value (including "json" and the empty string) becomes "JSON".
// The match is case-sensitive.
func MapSerializer(serializer string) string {
	if serializer == "thrift" {
		return "Thrift"
	}
	if serializer == "pb" {
		return "ProtoBuf"
	}
	// "json" and anything unrecognized fall back to JSON.
	return "JSON"
}
// Format finalizes every method of the handler in order: the method's
// Serializer is replaced by its MapSerializer translation and its comment is
// rebuilt through InitComment. The methods are modified in place.
func (h *Handler) Format() {
	for i := range h.Methods {
		method := h.Methods[i]
		method.Serializer = MapSerializer(method.Serializer)
		method.InitComment()
	}
}