232 lines
6.7 KiB
Go
232 lines
6.7 KiB
Go
package migration
|
||
|
||
import (
|
||
"fmt"
|
||
"os"
|
||
"time"
|
||
|
||
"git.toowon.com/jimmy/go-common/config"
|
||
"gorm.io/driver/mysql"
|
||
"gorm.io/driver/postgres"
|
||
"gorm.io/driver/sqlite"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// RunMigrationsFromConfig 从配置文件运行迁移(便捷方法)
|
||
//
|
||
// 注意:推荐使用独立的迁移工具(templates/migrate/main.go),而不是在应用代码中直接调用。
|
||
// 独立工具可以实现零耦合、独立部署。
|
||
//
|
||
// 此方法主要用于:
|
||
// 1. 独立迁移工具内部调用(推荐)
|
||
// 2. 简单场景下在应用启动时调用(不推荐,会导致耦合)
|
||
//
|
||
// 用法:
|
||
//
|
||
// import "git.toowon.com/jimmy/go-common/migration"
|
||
// migration.RunMigrationsFromConfig("config.json", "migrations")
|
||
// // 或使用默认迁移目录
|
||
// migration.RunMigrationsFromConfig("config.json", "")
|
||
func RunMigrationsFromConfig(configFile, migrationsDir string) error {
|
||
return RunMigrationsFromConfigWithCommand(configFile, migrationsDir, "up")
|
||
}
|
||
|
||
// RunMigrationsFromConfigWithCommand 从配置文件运行迁移(支持命令,黑盒模式)
|
||
//
|
||
// 这是最简单的迁移方式,内部自动处理:
|
||
// - 配置加载(支持文件、环境变量、默认路径)
|
||
// - 数据库连接(自动识别数据库类型)
|
||
// - 迁移文件加载和执行
|
||
//
|
||
// 参数:
|
||
// - configFile: 配置文件路径,支持:
|
||
// - 空字符串:自动查找(config.json, ../config.json)
|
||
// - 环境变量 DATABASE_URL:直接使用数据库URL
|
||
// - migrationsDir: 迁移文件目录,支持:
|
||
// - 空字符串:使用默认目录 "migrations"
|
||
// - 相对路径或绝对路径
|
||
// - command: 命令,支持 "up", "down", "status"
|
||
//
|
||
// 使用示例:
|
||
//
|
||
// // 最简单:使用默认配置和默认迁移目录
|
||
// migration.RunMigrationsFromConfigWithCommand("", "", "up")
|
||
//
|
||
// // 指定配置文件,使用默认迁移目录
|
||
// migration.RunMigrationsFromConfigWithCommand("config.json", "", "up")
|
||
//
|
||
// // 指定配置和迁移目录
|
||
// migration.RunMigrationsFromConfigWithCommand("config.json", "scripts/sql", "up")
|
||
//
|
||
// // 使用环境变量
|
||
// // DATABASE_URL="mysql://..." migration.RunMigrationsFromConfigWithCommand("", "migrations", "up")
|
||
func RunMigrationsFromConfigWithCommand(configFile, migrationsDir, command string) error {
|
||
// 加载配置
|
||
cfg, err := loadConfigFromFileOrEnv(configFile)
|
||
if err != nil {
|
||
return fmt.Errorf("加载配置失败: %w", err)
|
||
}
|
||
|
||
// 连接数据库
|
||
db, err := connectDB(cfg)
|
||
if err != nil {
|
||
return fmt.Errorf("连接数据库失败: %w", err)
|
||
}
|
||
|
||
// 使用默认迁移目录(如果未指定)
|
||
if migrationsDir == "" {
|
||
migrationsDir = "migrations"
|
||
}
|
||
|
||
// 创建迁移器
|
||
migrator := NewMigrator(db)
|
||
|
||
// 加载迁移文件
|
||
migrations, err := LoadMigrationsFromFiles(migrationsDir, "*.sql")
|
||
if err != nil {
|
||
return fmt.Errorf("加载迁移文件失败: %w", err)
|
||
}
|
||
|
||
if len(migrations) == 0 {
|
||
fmt.Printf("在目录 '%s' 中没有找到迁移文件\n", migrationsDir)
|
||
return nil
|
||
}
|
||
|
||
migrator.AddMigrations(migrations...)
|
||
|
||
// 执行命令
|
||
switch command {
|
||
case "up":
|
||
if err := migrator.Up(); err != nil {
|
||
return fmt.Errorf("执行迁移失败: %w", err)
|
||
}
|
||
fmt.Println("✓ 迁移执行成功")
|
||
|
||
case "down":
|
||
if err := migrator.Down(); err != nil {
|
||
return fmt.Errorf("回滚迁移失败: %w", err)
|
||
}
|
||
fmt.Println("✓ 迁移回滚成功")
|
||
|
||
case "status":
|
||
status, err := migrator.Status()
|
||
if err != nil {
|
||
return fmt.Errorf("获取迁移状态失败: %w", err)
|
||
}
|
||
printMigrationStatus(status)
|
||
|
||
default:
|
||
return fmt.Errorf("未知命令: %s (支持: up, down, status)", command)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// loadConfigFromFileOrEnv 从文件或环境变量加载配置
|
||
func loadConfigFromFileOrEnv(configFile string) (*config.Config, error) {
|
||
// 优先从环境变量加载
|
||
if dbURL := os.Getenv("DATABASE_URL"); dbURL != "" {
|
||
return &config.Config{
|
||
Database: &config.DatabaseConfig{
|
||
DSN: dbURL,
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
// 尝试从配置文件加载
|
||
if configFile != "" {
|
||
if _, err := os.Stat(configFile); err == nil {
|
||
return config.LoadFromFile(configFile)
|
||
}
|
||
}
|
||
|
||
// 尝试默认路径
|
||
defaultPaths := []string{"config.json", "../config.json"}
|
||
for _, path := range defaultPaths {
|
||
if _, err := os.Stat(path); err == nil {
|
||
return config.LoadFromFile(path)
|
||
}
|
||
}
|
||
|
||
return nil, fmt.Errorf("未找到配置文件,也未设置环境变量 DATABASE_URL")
|
||
}
|
||
|
||
// connectDB 连接数据库
|
||
// 与 factory.getDatabase 保持一致的实现,避免代码重复
|
||
func connectDB(cfg *config.Config) (*gorm.DB, error) {
|
||
if cfg.Database == nil {
|
||
return nil, fmt.Errorf("数据库配置为空")
|
||
}
|
||
|
||
dsn, err := cfg.GetDatabaseDSN()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var db *gorm.DB
|
||
switch cfg.Database.Type {
|
||
case "mysql":
|
||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||
case "postgres":
|
||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||
case "sqlite":
|
||
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||
default:
|
||
return nil, fmt.Errorf("不支持的数据库类型: %s", cfg.Database.Type)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 配置连接池(与 factory.getDatabase 保持一致)
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 使用配置文件中的连接池参数,如果没有配置则使用默认值
|
||
if cfg.Database.MaxOpenConns > 0 {
|
||
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
|
||
} else {
|
||
sqlDB.SetMaxOpenConns(10) // 默认值
|
||
}
|
||
|
||
if cfg.Database.MaxIdleConns > 0 {
|
||
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
|
||
} else {
|
||
sqlDB.SetMaxIdleConns(5) // 默认值
|
||
}
|
||
|
||
if cfg.Database.ConnMaxLifetime > 0 {
|
||
sqlDB.SetConnMaxLifetime(time.Duration(cfg.Database.ConnMaxLifetime) * time.Second)
|
||
} else {
|
||
sqlDB.SetConnMaxLifetime(time.Hour) // 默认值
|
||
}
|
||
|
||
return db, nil
|
||
}
|
||
|
||
// printMigrationStatus 打印迁移状态
|
||
func printMigrationStatus(status []MigrationStatus) {
|
||
if len(status) == 0 {
|
||
fmt.Println("没有找到迁移")
|
||
return
|
||
}
|
||
|
||
fmt.Println("\n迁移状态:")
|
||
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||
fmt.Printf("%-20s %-40s %-10s\n", "版本", "描述", "状态")
|
||
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||
|
||
for _, s := range status {
|
||
statusText := "待执行"
|
||
if s.Applied {
|
||
statusText = "✓ 已应用"
|
||
}
|
||
fmt.Printf("%-20s %-40s %-10s\n", s.Version, s.Description, statusText)
|
||
}
|
||
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||
fmt.Println()
|
||
}
|