// Code generated by hertz generator.

// Command hertz_scaffold creates a new project directory from a remote git
// template repository, renaming the template's Go module to the project name.
package main

import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"io/fs"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"sort"
	"strings"
)

// Defaults for the "create" subcommand flags, and the usage line printed on
// an invalid invocation.
const (
	defaultTemplateRepo = "git@golib-ssh.gaore.com:GaoreGo/hertz_demo.git"
	defaultBranch       = "master"
	usage               = "Usage: hertz_scaffold create -p <project_path> [-t <template_repo>] [-b <branch>]"
)

func main() {
	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, usage)
		os.Exit(1)
	}

	// Dispatch on the subcommand.
	switch subCommand := os.Args[1]; subCommand {
	case "create":
		handleCreateCommand(os.Args[2:])
	case "update":
		// "update" is a recognised name but has no implementation yet; fail
		// loudly instead of silently exiting with status 0.
		log.Printf("Subcommand %q is not implemented yet\n", subCommand)
		os.Exit(1)
	default:
		log.Printf("Unknown subcommand: %s\n", subCommand)
		fmt.Fprintln(os.Stderr, usage)
		os.Exit(1)
	}
}

// handleCreateCommand implements "create". It parses args (-p is required,
// -t and -b override the template repository and branch), refuses to write
// into an existing non-empty path, and then materializes the template at -p.
// It exits the process with status 1 on any error.
func handleCreateCommand(args []string) {
	var templateRepo, branch, projectPath string

	flagSet := flag.NewFlagSet("create", flag.ExitOnError)
	flagSet.StringVar(&projectPath, "p", "", "Path to the new project")
	flagSet.StringVar(&templateRepo, "t", defaultTemplateRepo, "Git URL of the template repository")
	flagSet.StringVar(&branch, "b", defaultBranch, "Template branch to clone")
	// With flag.ExitOnError, Parse exits the process itself on bad input,
	// so the returned error is always nil here.
	_ = flagSet.Parse(args)

	if projectPath == "" {
		log.Printf("Error: -p flag is required\n")
		os.Exit(1)
	}

	// Check before cloning: existing files would otherwise be truncated and
	// overwritten, and a doomed run should not pay for a network clone.
	if err := checkTargetDir(projectPath); err != nil {
		log.Printf("Error creating project: %s\n", err)
		os.Exit(1)
	}

	if err := handleRemoteTemplate(templateRepo, branch, projectPath); err != nil {
		log.Printf("Error creating project: %s\n", err)
		os.Exit(1)
	}
	fmt.Printf("Project %s created successfully!\n", projectPath)
}

// checkTargetDir accepts a path that does not exist or is an empty directory.
// It returns an error when path is a regular file, a non-empty directory, or
// cannot be inspected.
func checkTargetDir(path string) error {
	info, err := os.Stat(path)
	if errors.Is(err, fs.ErrNotExist) {
		return nil
	}
	if err != nil {
		return fmt.Errorf("checking %q: %w", path, err)
	}
	if !info.IsDir() {
		return fmt.Errorf("%q already exists and is not a directory", path)
	}
	entries, err := os.ReadDir(path)
	if err != nil {
		return fmt.Errorf("reading %q: %w", path, err)
	}
	if len(entries) > 0 {
		return fmt.Errorf("%q already exists and is not empty", path)
	}
	return nil
}

// getModuleName returns the module path declared by the first "module"
// directive in the go.mod file at goModPath. Trailing "//" comments and
// surrounding quotes/backquotes are removed. A missing or empty module path
// is reported as an error, because callers use the result as a replacement
// key and an empty key would match everywhere.
func getModuleName(goModPath string) (string, error) {
	content, err := os.ReadFile(goModPath)
	if err != nil {
		return "", err
	}
	for _, line := range strings.Split(string(content), "\n") {
		if i := strings.Index(line, "//"); i >= 0 {
			line = line[:i]
		}
		// Fields also strips a trailing "\r" from CRLF files.
		fields := strings.Fields(line)
		if len(fields) < 2 || fields[0] != "module" {
			continue
		}
		name := strings.Trim(fields[1], "\"`")
		if name == "" {
			break
		}
		return name, nil
	}
	return "", fmt.Errorf("module name not found in go.mod")
}

// handleRemoteTemplate shallow-clones branch of templateRepo into a temporary
// directory, copies it to projectPath via copyTemplate, and removes the
// temporary clone afterwards. It requires git on PATH; git's output is
// streamed to this process's stdout/stderr.
func handleRemoteTemplate(templateRepo, branch, projectPath string) error {
	tempDir, err := os.MkdirTemp("", "template-*")
	if err != nil {
		return fmt.Errorf("error creating temporary directory: %w", err)
	}
	// Best-effort cleanup of the temporary clone.
	defer os.RemoveAll(tempDir)

	// --depth 1: history is never needed because .git is skipped when
	// copying. "--" keeps a repository value starting with "-" from being
	// parsed as a git option.
	cloneCmd := exec.Command("git", "clone", "--depth", "1", "-b", branch, "--", templateRepo, tempDir)
	cloneCmd.Stdout = os.Stdout
	cloneCmd.Stderr = os.Stderr
	if err := cloneCmd.Run(); err != nil {
		return fmt.Errorf("error cloning template repository: %w", err)
	}

	return copyTemplate(tempDir, projectPath)
}

// copyTemplate copies the template tree at src into dist, skipping any .git
// directory. In text files, every occurrence of the template's module name
// (read from src/go.mod) is replaced with the base name of dist's absolute
// path, so that "-p ." still yields a usable module name. Errors are
// returned, not fatal, so the caller decides how to fail.
func copyTemplate(src, dist string) error {
	oldModuleName, err := getModuleName(filepath.Join(src, "go.mod"))
	if err != nil {
		return fmt.Errorf("error reading module name: %w", err)
	}
	absDist, err := filepath.Abs(dist)
	if err != nil {
		return fmt.Errorf("resolving project path %q: %w", dist, err)
	}
	replacements := map[string]string{
		oldModuleName: filepath.Base(absDist),
	}

	return filepath.Walk(src, func(path string, info fs.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.IsDir() && info.Name() == ".git" {
			return filepath.SkipDir
		}

		relPath, err := filepath.Rel(src, path)
		if err != nil {
			return err
		}
		targetPath := filepath.Join(dist, relPath)

		if info.IsDir() {
			perm := info.Mode().Perm()
			if relPath == "." {
				// src is typically an os.MkdirTemp directory (mode 0700);
				// don't make the new project root private.
				perm = 0o755
			}
			return os.MkdirAll(targetPath, perm)
		}
		return copyAndReplaceFile(path, targetPath, info.Mode(), replacements)
	})
}

// copyAndReplaceFile writes src to dist (creating or truncating it, with mode
// applied on creation), applying replacements to the content first. Files
// containing a NUL byte are treated as binary and copied verbatim: binaries
// such as compiled Go programs embed the module path, and a length-changing
// rewrite would corrupt them.
func copyAndReplaceFile(src, dist string, mode os.FileMode, replacements map[string]string) error {
	content, err := os.ReadFile(src)
	if err != nil {
		return err
	}
	if bytes.IndexByte(content, 0) < 0 {
		content = []byte(newReplacer(replacements).Replace(string(content)))
	}
	// os.WriteFile also returns the Close error, which can carry a deferred
	// write failure that a plain "defer f.Close()" would discard.
	return os.WriteFile(dist, content, mode)
}

// newReplacer builds a single-pass replacer from replacements.
//   - Empty keys are skipped; replacing "" would insert the value between
//     every rune.
//   - Keys are ordered longest first, with ties broken lexically. The result
//     therefore does not depend on map iteration order, and a longer key wins
//     over one of its prefixes.
func newReplacer(replacements map[string]string) *strings.Replacer {
	keys := make([]string, 0, len(replacements))
	for k := range replacements {
		if k != "" {
			keys = append(keys, k)
		}
	}
	sort.Slice(keys, func(i, j int) bool {
		if len(keys[i]) != len(keys[j]) {
			return len(keys[i]) > len(keys[j])
		}
		return keys[i] < keys[j]
	})

	oldnew := make([]string, 0, 2*len(keys))
	for _, k := range keys {
		oldnew = append(oldnew, k, replacements[k])
	}
	return strings.NewReplacer(oldnew...)
}