2024-08-08 21:29:58 +08:00
|
|
|
package cmd
|
|
|
|
|
|
|
|
import (
|
2024-08-09 14:57:41 +08:00
|
|
|
"bytes"
|
2024-08-08 21:29:58 +08:00
|
|
|
"fmt"
|
2024-08-09 14:57:41 +08:00
|
|
|
"github.com/spf13/cobra"
|
2024-08-08 21:29:58 +08:00
|
|
|
"io"
|
|
|
|
"io/fs"
|
|
|
|
"log"
|
|
|
|
"os"
|
|
|
|
"os/exec"
|
|
|
|
"path/filepath"
|
|
|
|
"strings"
|
|
|
|
)
|
|
|
|
|
|
|
|
// getModuleName returns the module path declared by the "module" directive
// of the go.mod file at goModPath.
//
// The directive is recognised regardless of leading indentation and of the
// whitespace (spaces or tabs) separating the keyword from the path. A trailing
// "//" line comment is ignored and surrounding quotes are stripped from the
// path.
//
// The error from os.ReadFile is returned unchanged when the file cannot be
// read; a "module name not found" error is returned when the file contains no
// module directive with a non-empty path.
func getModuleName(goModPath string) (string, error) {
	content, err := os.ReadFile(goModPath)
	if err != nil {
		return "", err
	}

	for _, line := range strings.Split(string(content), "\n") {
		// Drop a trailing line comment so it cannot end up in the name.
		if i := strings.Index(line, "//"); i >= 0 {
			line = line[:i]
		}

		// Fields splits on any run of whitespace, which also discards the
		// "\r" left behind by CRLF line endings.
		fields := strings.Fields(line)
		if len(fields) < 2 || fields[0] != "module" {
			continue
		}

		// go.mod allows the path to be written as a quoted string.
		if name := strings.Trim(fields[1], "\"`"); name != "" {
			return name, nil
		}
	}

	return "", fmt.Errorf("module name not found in go.mod")
}
|
|
|
|
|
|
|
|
func handleRemoteTemplate(templateRepo, branch, projectPath string) (err error) {
|
|
|
|
// 创建临时目录
|
|
|
|
tempDir, err := os.MkdirTemp("", "template-*")
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error creating temporary directory: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// 清理临时目录
|
|
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
|
|
|
|
// 克隆模板仓库
|
2024-08-09 14:57:41 +08:00
|
|
|
cloneCmd := exec.Command("git", "clone", "--branch", branch, templateRepo, tempDir)
|
2024-08-08 21:29:58 +08:00
|
|
|
cloneCmd.Stdout = io.Discard
|
2024-08-09 14:57:41 +08:00
|
|
|
var stderr bytes.Buffer
|
|
|
|
cloneCmd.Stderr = &stderr
|
2024-08-08 21:29:58 +08:00
|
|
|
|
|
|
|
if err = cloneCmd.Run(); err != nil {
|
2024-08-09 14:57:41 +08:00
|
|
|
return fmt.Errorf("error cloning template repository: %s\n%s", err, extractError(stderr.String()))
|
2024-08-08 21:29:58 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
return copyTemplate(tempDir, projectPath)
|
|
|
|
}
|
|
|
|
|
|
|
|
func copyTemplate(src, dist string) (err error) {
|
|
|
|
// 读取 go.mod 文件中的模块名称
|
|
|
|
oldModuleName, err := getModuleName(filepath.Join(src, "go.mod"))
|
|
|
|
if err != nil {
|
|
|
|
log.Printf("error reading module name: %s", err)
|
|
|
|
os.Exit(1)
|
|
|
|
}
|
|
|
|
replacements := map[string]string{
|
|
|
|
oldModuleName: filepath.Base(dist),
|
|
|
|
}
|
|
|
|
|
|
|
|
return filepath.Walk(src, func(path string, info fs.FileInfo, err error) error {
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// 获取相对路径
|
|
|
|
relPath, err := filepath.Rel(src, path)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// 获取目标路径
|
|
|
|
targetPath := filepath.Join(dist, relPath)
|
|
|
|
|
|
|
|
if info.IsDir() && filepath.Base(path) == ".git" {
|
|
|
|
return filepath.SkipDir
|
|
|
|
}
|
|
|
|
|
|
|
|
if info.IsDir() {
|
|
|
|
// 创建目录
|
|
|
|
return os.MkdirAll(targetPath, info.Mode())
|
|
|
|
}
|
|
|
|
|
|
|
|
return copyAndReplaceFile(path, targetPath, info.Mode(), replacements)
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
// copyAndReplaceFile writes the content of src to dist after replacing every
// occurrence of each key in replacements with its associated value.
//
// dist is created with the given mode if it does not exist and truncated if
// it does. The whole file is held in memory while it is processed. When
// replacements has several entries they are applied in map iteration order,
// which is unspecified; overlapping keys may therefore give varying results.
//
// The first error from reading src, opening dist, writing, or closing dist is
// returned. The close error is reported because a failed close can mean the
// written data never reached the file.
func copyAndReplaceFile(src, dist string, mode os.FileMode, replacements map[string]string) (err error) {
	// Read the source file.
	content, err := os.ReadFile(src)
	if err != nil {
		return err
	}

	newContent := string(content)
	for key, value := range replacements {
		newContent = strings.ReplaceAll(newContent, key, value)
	}

	// Open (or create) the destination file for writing.
	targetFile, err := os.OpenFile(dist, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
	if err != nil {
		return err
	}
	defer func() {
		// Keep an earlier write error; otherwise surface the close error.
		if closeErr := targetFile.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	_, err = targetFile.WriteString(newContent)
	return err
}
|
2024-08-09 14:57:41 +08:00
|
|
|
|
|
|
|
// extractError reduces a command's stderr output to the lines that carry an
// actual error message: lines that start with "fatal:" or that contain
// "error:" anywhere. Matching lines are returned in their original order,
// joined by newlines; the result is empty when no line matches.
func extractError(stderr string) string {
	var kept []string
	for _, ln := range strings.Split(stderr, "\n") {
		startsFatal := strings.HasPrefix(ln, "fatal:")
		mentionsError := strings.Contains(ln, "error:")
		if !startsFatal && !mentionsError {
			// Not an error line (progress, hints, ...); leave it out.
			continue
		}
		kept = append(kept, ln)
	}
	return strings.Join(kept, "\n")
}
|
|
|
|
|
|
|
|
// initFlags 初始化通用的命令行参数
|
|
|
|
func initFlags(cmd *cobra.Command) {
|
|
|
|
cmd.Flags().StringVarP(&project, "project", "p", "", "项目名称")
|
|
|
|
cmd.Flags().StringVarP(&branch, "tag", "t", "master", "指定tag")
|
|
|
|
cmd.MarkFlagRequired("project")
|
|
|
|
}
|