gr_hz/thrift/resolver.go
2024-04-30 23:07:06 +08:00

593 lines
14 KiB
Go

/*
* Copyright 2022 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package thrift
import (
"fmt"
"strings"
"github.com/cloudwego/thriftgo/parser"
"golib.gaore.com/GaoreGo/gr_hz/generator/model"
"golib.gaore.com/GaoreGo/gr_hz/util"
)
var (
ConstTrue = Symbol{
IsValue: true,
Type: model.TypeBool,
Value: true,
Scope: &BaseThrift,
}
ConstFalse = Symbol{
IsValue: true,
Type: model.TypeBool,
Value: false,
Scope: &BaseThrift,
}
ConstEmptyString = Symbol{
IsValue: true,
Type: model.TypeString,
Value: "",
Scope: &BaseThrift,
}
)
type PackageReference struct {
IncludeBase string
IncludePath string
Model *model.Model
Ast *parser.Thrift
Referred bool
}
func getReferPkgMap(pkgMap map[string]string, incs []*parser.Include, mainModel *model.Model) (map[string]*PackageReference, error) {
var err error
out := make(map[string]*PackageReference, len(pkgMap))
pkgAliasMap := make(map[string]string, len(incs))
// bugfix: add main package to avoid namespace conflict
mainPkg := mainModel.Package
mainPkgName := mainModel.PackageName
mainPkgName, err = util.GetPackageUniqueName(mainPkgName)
if err != nil {
return nil, err
}
pkgAliasMap[mainPkg] = mainPkgName
for _, inc := range incs {
pkg := getGoPackage(inc.Reference, pkgMap)
impt := inc.GetPath()
base := util.BaseNameAndTrim(impt)
pkgName := util.SplitPackageName(pkg, "")
if pn, exist := pkgAliasMap[pkg]; exist {
pkgName = pn
} else {
pkgName, err = util.GetPackageUniqueName(pkgName)
pkgAliasMap[pkg] = pkgName
if err != nil {
return nil, fmt.Errorf("get package unique name failed, err: %v", err)
}
}
out[base] = &PackageReference{base, impt, &model.Model{
FilePath: inc.Path,
Package: pkg,
PackageName: pkgName,
}, inc.Reference, false}
}
return out, nil
}
type Symbol struct {
IsValue bool
Type *model.Type
Value interface{}
Scope *parser.Thrift
}
type NameSpace map[string]*Symbol
type Resolver struct {
// idl symbols
root NameSpace
deps map[string]NameSpace
// exported models
mainPkg PackageReference
refPkgs map[string]*PackageReference
}
func NewResolver(ast *parser.Thrift, model *model.Model, pkgMap map[string]string) (*Resolver, error) {
pm, err := getReferPkgMap(pkgMap, ast.GetIncludes(), model)
if err != nil {
return nil, fmt.Errorf("get package map failed, err: %v", err)
}
file := ast.GetFilename()
return &Resolver{
root: make(NameSpace),
deps: make(map[string]NameSpace),
refPkgs: pm,
mainPkg: PackageReference{
IncludeBase: util.BaseNameAndTrim(file),
IncludePath: ast.GetFilename(),
Model: model,
Ast: ast,
Referred: false,
},
}, nil
}
func (resolver *Resolver) GetRefModel(includeBase string) (*model.Model, error) {
if includeBase == "" {
return resolver.mainPkg.Model, nil
}
ref, ok := resolver.refPkgs[includeBase]
if !ok {
return nil, fmt.Errorf("not found include %s", includeBase)
}
return ref.Model, nil
}
func (resolver *Resolver) getBaseType(typ *parser.Type) (*model.Type, bool) {
tt := switchBaseType(typ)
if tt != nil {
return tt, true
}
if typ.Name == "map" {
t := *model.TypeBaseMap
return &t, false
}
if typ.Name == "list" {
t := *model.TypeBaseList
return &t, false
}
if typ.Name == "set" {
t := *model.TypeBaseList
return &t, false
}
return nil, false
}
func (resolver *Resolver) ResolveType(typ *parser.Type) (*model.Type, error) {
bt, base := resolver.getBaseType(typ)
if bt != nil {
if base {
return bt, nil
} else {
if typ.Name == model.TypeBaseMap.Name {
resolveKey, err := resolver.ResolveType(typ.KeyType)
if err != nil {
return nil, err
}
resolveValue, err := resolver.ResolveType(typ.ValueType)
if err != nil {
return nil, err
}
bt.Extra = append(bt.Extra, resolveKey, resolveValue)
} else if typ.Name == model.TypeBaseList.Name || typ.Name == model.TypeBaseSet.Name {
resolveValue, err := resolver.ResolveType(typ.ValueType)
if err != nil {
return nil, err
}
bt.Extra = append(bt.Extra, resolveValue)
} else {
return nil, fmt.Errorf("invalid DefinitionType(%+v)", bt)
}
return bt, nil
}
}
id := typ.GetName()
rs, err := resolver.ResolveIdentifier(id)
if err != nil {
return nil, err
}
sb := rs.Symbol
if sb == nil {
return nil, fmt.Errorf("not found identifier %s", id)
}
return sb.Type, nil
}
func (resolver *Resolver) ResolveConstantValue(constant *parser.ConstValue) (model.Literal, error) {
switch constant.Type {
case parser.ConstType_ConstInt:
return model.IntExpression{Src: int(constant.TypedValue.GetInt())}, nil
case parser.ConstType_ConstDouble:
return model.DoubleExpression{Src: constant.TypedValue.GetDouble()}, nil
case parser.ConstType_ConstLiteral:
return model.StringExpression{Src: constant.TypedValue.GetLiteral()}, nil
case parser.ConstType_ConstList:
eleType, err := switchConstantType(constant.Type)
if err != nil {
return nil, err
}
ret := model.ListExpression{
ElementType: eleType,
}
for _, i := range constant.TypedValue.List {
elem, err := resolver.ResolveConstantValue(i)
if err != nil {
return nil, err
}
ret.Elements = append(ret.Elements, elem)
}
return ret, nil
case parser.ConstType_ConstMap:
keyType, err := switchConstantType(constant.TypedValue.Map[0].Key.Type)
if err != nil {
return nil, err
}
valueType, err := switchConstantType(constant.TypedValue.Map[0].Value.Type)
if err != nil {
return nil, err
}
ret := model.MapExpression{
KeyType: keyType,
ValueType: valueType,
Elements: make(map[string]model.Literal, len(constant.TypedValue.Map)),
}
for _, v := range constant.TypedValue.Map {
value, err := resolver.ResolveConstantValue(v.Value)
if err != nil {
return nil, err
}
ret.Elements[v.Key.String()] = value
}
return ret, nil
case parser.ConstType_ConstIdentifier:
return resolver.ResolveIdentifier(*constant.TypedValue.Identifier)
}
return model.StringExpression{Src: constant.String()}, nil
}
func (resolver *Resolver) ResolveIdentifier(id string) (ret ResolvedSymbol, err error) {
sb := resolver.Get(id)
if sb == nil {
return ResolvedSymbol{}, fmt.Errorf("identifier '%s' not found", id)
}
ret.Symbol = sb
ret.Base = id
if sb.Scope == &BaseThrift {
return
}
if sb.Scope == resolver.mainPkg.Ast {
resolver.mainPkg.Referred = true
ret.Src = resolver.mainPkg.Model.PackageName
return
}
sp := strings.SplitN(id, ".", 2)
if ref, ok := resolver.refPkgs[sp[0]]; ok {
ref.Referred = true
ret.Base = sp[1]
ret.Src = ref.Model.PackageName
ret.Type.Scope = ref.Model
} else {
return ResolvedSymbol{}, fmt.Errorf("can't resolve identifier '%s'", id)
}
return
}
func (resolver *Resolver) ResolveTypeName(typ *parser.Type) (string, error) {
if typ.GetIsTypedef() {
rt, err := resolver.ResolveIdentifier(typ.GetName())
if err != nil {
return "", err
}
return rt.Expression(), nil
}
switch typ.GetCategory() {
case parser.Category_Map:
keyType, err := resolver.ResolveTypeName(typ.GetKeyType())
if err != nil {
return "", err
}
if typ.GetKeyType().GetCategory().IsStruct() {
keyType = "*" + keyType
}
valueType, err := resolver.ResolveTypeName(typ.GetValueType())
if err != nil {
return "", err
}
if typ.GetValueType().GetCategory().IsStruct() {
valueType = "*" + valueType
}
return fmt.Sprintf("map[%s]%s", keyType, valueType), nil
case parser.Category_List, parser.Category_Set:
// list/set -> []element for thriftgo
// valueType refers the element type for list/set
elemType, err := resolver.ResolveTypeName(typ.GetValueType())
if err != nil {
return "", err
}
if typ.GetValueType().GetCategory().IsStruct() {
elemType = "*" + elemType
}
return fmt.Sprintf("[]%s", elemType), err
}
rt, err := resolver.ResolveIdentifier(typ.GetName())
if err != nil {
return "", err
}
return rt.Expression(), nil
}
func (resolver *Resolver) Get(name string) *Symbol {
s, ok := resolver.root[name]
if ok {
return s
}
if strings.Contains(name, ".") {
sp := strings.SplitN(name, ".", 2)
if ref, ok := resolver.deps[sp[0]]; ok {
if ss, ok := ref[sp[1]]; ok {
return ss
}
}
}
return nil
}
func (resolver *Resolver) ExportReferred(all, needMain bool) (ret []*PackageReference) {
for _, v := range resolver.refPkgs {
if all {
ret = append(ret, v)
v.Referred = false
} else if v.Referred {
ret = append(ret, v)
v.Referred = false
}
}
if needMain && (all || resolver.mainPkg.Referred) {
ret = append(ret, &resolver.mainPkg)
}
resolver.mainPkg.Referred = false
return
}
func (resolver *Resolver) LoadAll(ast *parser.Thrift) error {
var err error
resolver.root, err = resolver.LoadOne(ast)
if err != nil {
return fmt.Errorf("load root package: %s", err)
}
includes := ast.GetIncludes()
astMap := make(map[string]NameSpace, len(includes))
for _, dep := range includes {
bName := util.BaseName(dep.Path, ".thrift")
astMap[bName], err = resolver.LoadOne(dep.Reference)
if err != nil {
return fmt.Errorf("load idl %s: %s", dep.Path, err)
}
}
resolver.deps = astMap
for _, td := range ast.Typedefs {
name := td.GetAlias()
if _, ex := resolver.root[name]; ex {
if resolver.root[name].Type != nil {
typ := newTypedefType(resolver.root[name].Type, name)
resolver.root[name].Type = &typ
continue
}
}
sym := resolver.Get(td.Type.GetName())
typ := newTypedefType(sym.Type, name)
resolver.root[name].Type = &typ
}
return nil
}
func LoadBaseIdentifier() NameSpace {
ret := make(NameSpace, 16)
ret["true"] = &ConstTrue
ret["false"] = &ConstFalse
ret[`""`] = &ConstEmptyString
ret["bool"] = &Symbol{
Type: model.TypeBool,
Scope: &BaseThrift,
}
ret["byte"] = &Symbol{
Type: model.TypeByte,
Scope: &BaseThrift,
}
ret["i8"] = &Symbol{
Type: model.TypeInt8,
Scope: &BaseThrift,
}
ret["i16"] = &Symbol{
Type: model.TypeInt16,
Scope: &BaseThrift,
}
ret["i32"] = &Symbol{
Type: model.TypeInt32,
Scope: &BaseThrift,
}
ret["i64"] = &Symbol{
Type: model.TypeInt64,
Scope: &BaseThrift,
}
ret["int"] = &Symbol{
Type: model.TypeInt,
Scope: &BaseThrift,
}
ret["double"] = &Symbol{
Type: model.TypeFloat64,
Scope: &BaseThrift,
}
ret["string"] = &Symbol{
Type: model.TypeString,
Scope: &BaseThrift,
}
ret["binary"] = &Symbol{
Type: model.TypeBinary,
Scope: &BaseThrift,
}
ret["list"] = &Symbol{
Type: model.TypeBaseList,
Scope: &BaseThrift,
}
ret["set"] = &Symbol{
Type: model.TypeBaseSet,
Scope: &BaseThrift,
}
ret["map"] = &Symbol{
Type: model.TypeBaseMap,
Scope: &BaseThrift,
}
return ret
}
func (resolver *Resolver) LoadOne(ast *parser.Thrift) (NameSpace, error) {
ret := LoadBaseIdentifier()
for _, e := range ast.Enums {
prefix := e.GetName()
ret[prefix] = &Symbol{
IsValue: false,
Value: e,
Scope: ast,
Type: newEnumType(prefix, model.CategoryEnum),
}
for _, ee := range e.Values {
name := prefix + "." + ee.GetName()
if _, exist := ret[name]; exist {
return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename)
}
ret[name] = &Symbol{
IsValue: true,
Value: ee,
Scope: ast,
Type: newBaseType(model.TypeInt, model.CategoryEnum),
}
}
}
for _, e := range ast.Constants {
name := e.GetName()
if _, exist := ret[name]; exist {
return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename)
}
gt, _ := resolver.getBaseType(e.Type)
ret[name] = &Symbol{
IsValue: true,
Value: e,
Scope: ast,
Type: gt,
}
}
for _, e := range ast.Structs {
name := e.GetName()
if _, exist := ret[name]; exist {
return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename)
}
ret[name] = &Symbol{
IsValue: false,
Value: e,
Scope: ast,
Type: newStructType(name, model.CategoryStruct),
}
}
for _, e := range ast.Unions {
name := e.GetName()
if _, exist := ret[name]; exist {
return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename)
}
ret[name] = &Symbol{
IsValue: false,
Value: e,
Scope: ast,
Type: newStructType(name, model.CategoryStruct),
}
}
for _, e := range ast.Exceptions {
name := e.GetName()
if _, exist := ret[name]; exist {
return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename)
}
ret[name] = &Symbol{
IsValue: false,
Value: e,
Scope: ast,
Type: newStructType(name, model.CategoryStruct),
}
}
for _, e := range ast.Services {
name := e.GetName()
if _, exist := ret[name]; exist {
return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename)
}
ret[name] = &Symbol{
IsValue: false,
Value: e,
Scope: ast,
Type: newFuncType(name, model.CategoryService),
}
}
for _, td := range ast.Typedefs {
name := td.GetAlias()
if _, exist := ret[name]; exist {
return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename)
}
gt, _ := resolver.getBaseType(td.Type)
if gt == nil {
sym := ret[td.Type.Name]
if sym != nil {
gt = sym.Type
}
}
ret[name] = &Symbol{
IsValue: false,
Value: td,
Scope: ast,
Type: gt,
}
}
return ret, nil
}
func switchConstantType(constant parser.ConstType) (*model.Type, error) {
switch constant {
case parser.ConstType_ConstInt:
return model.TypeInt, nil
case parser.ConstType_ConstDouble:
return model.TypeFloat64, nil
case parser.ConstType_ConstLiteral:
return model.TypeString, nil
default:
return nil, fmt.Errorf("unknown constant type %d", constant)
}
}
func newTypedefType(t *model.Type, name string) model.Type {
tmp := t
typ := *tmp
typ.Name = name
typ.Category = model.CategoryTypedef
return typ
}