调整迁移的逻辑
This commit is contained in:
@@ -25,46 +25,41 @@ import (
|
||||
//
|
||||
// import "git.toowon.com/jimmy/go-common/migration"
|
||||
// migration.RunMigrationsFromConfig("config.json", "migrations")
|
||||
// // 或使用默认迁移目录
|
||||
// migration.RunMigrationsFromConfig("config.json", "")
|
||||
func RunMigrationsFromConfig(configFile, migrationsDir string) error {
|
||||
// 加载配置
|
||||
cfg, err := loadConfigFromFileOrEnv(configFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加载配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
db, err := connectDB(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建迁移器
|
||||
migrator := NewMigrator(db)
|
||||
|
||||
// 加载迁移文件
|
||||
migrations, err := LoadMigrationsFromFiles(migrationsDir, "*.sql")
|
||||
if err != nil {
|
||||
return fmt.Errorf("加载迁移文件失败: %w", err)
|
||||
}
|
||||
|
||||
if len(migrations) == 0 {
|
||||
fmt.Printf("在目录 '%s' 中没有找到迁移文件\n", migrationsDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
migrator.AddMigrations(migrations...)
|
||||
|
||||
// 执行迁移
|
||||
if err := migrator.Up(); err != nil {
|
||||
return fmt.Errorf("执行迁移失败: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ 迁移执行成功")
|
||||
return nil
|
||||
return RunMigrationsFromConfigWithCommand(configFile, migrationsDir, "up")
|
||||
}
|
||||
|
||||
// RunMigrationsFromConfigWithCommand 从配置文件运行迁移(支持命令)
|
||||
// command: "up", "down", "status"
|
||||
// RunMigrationsFromConfigWithCommand 从配置文件运行迁移(支持命令,黑盒模式)
|
||||
//
|
||||
// 这是最简单的迁移方式,内部自动处理:
|
||||
// - 配置加载(支持文件、环境变量、默认路径)
|
||||
// - 数据库连接(自动识别数据库类型)
|
||||
// - 迁移文件加载和执行
|
||||
//
|
||||
// 参数:
|
||||
// - configFile: 配置文件路径,支持:
|
||||
// - 空字符串:自动查找(config.json, ../config.json)
|
||||
// - 环境变量 DATABASE_URL:直接使用数据库URL
|
||||
// - migrationsDir: 迁移文件目录,支持:
|
||||
// - 空字符串:使用默认目录 "migrations"
|
||||
// - 相对路径或绝对路径
|
||||
// - command: 命令,支持 "up", "down", "status"
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// // 最简单:使用默认配置和默认迁移目录
|
||||
// migration.RunMigrationsFromConfigWithCommand("", "", "up")
|
||||
//
|
||||
// // 指定配置文件,使用默认迁移目录
|
||||
// migration.RunMigrationsFromConfigWithCommand("config.json", "", "up")
|
||||
//
|
||||
// // 指定配置和迁移目录
|
||||
// migration.RunMigrationsFromConfigWithCommand("config.json", "scripts/sql", "up")
|
||||
//
|
||||
// // 使用环境变量
|
||||
// // DATABASE_URL="mysql://..." migration.RunMigrationsFromConfigWithCommand("", "migrations", "up")
|
||||
func RunMigrationsFromConfigWithCommand(configFile, migrationsDir, command string) error {
|
||||
// 加载配置
|
||||
cfg, err := loadConfigFromFileOrEnv(configFile)
|
||||
@@ -78,6 +73,11 @@ func RunMigrationsFromConfigWithCommand(configFile, migrationsDir, command strin
|
||||
return fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 使用默认迁移目录(如果未指定)
|
||||
if migrationsDir == "" {
|
||||
migrationsDir = "migrations"
|
||||
}
|
||||
|
||||
// 创建迁移器
|
||||
migrator := NewMigrator(db)
|
||||
|
||||
@@ -152,6 +152,7 @@ func loadConfigFromFileOrEnv(configFile string) (*config.Config, error) {
|
||||
}
|
||||
|
||||
// connectDB 连接数据库
|
||||
// 与 factory.getDatabase 保持一致的实现,避免代码重复
|
||||
func connectDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
if cfg.Database == nil {
|
||||
return nil, fmt.Errorf("数据库配置为空")
|
||||
@@ -178,15 +179,30 @@ func connectDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
// 配置连接池(与 factory.getDatabase 保持一致)
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlDB.SetMaxOpenConns(10)
|
||||
sqlDB.SetMaxIdleConns(5)
|
||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||
// 使用配置文件中的连接池参数,如果没有配置则使用默认值
|
||||
if cfg.Database.MaxOpenConns > 0 {
|
||||
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
|
||||
} else {
|
||||
sqlDB.SetMaxOpenConns(10) // 默认值
|
||||
}
|
||||
|
||||
if cfg.Database.MaxIdleConns > 0 {
|
||||
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
|
||||
} else {
|
||||
sqlDB.SetMaxIdleConns(5) // 默认值
|
||||
}
|
||||
|
||||
if cfg.Database.ConnMaxLifetime > 0 {
|
||||
sqlDB.SetConnMaxLifetime(time.Duration(cfg.Database.ConnMaxLifetime) * time.Second)
|
||||
} else {
|
||||
sqlDB.SetConnMaxLifetime(time.Hour) // 默认值
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
@@ -213,4 +229,3 @@ func printMigrationStatus(status []MigrationStatus) {
|
||||
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
|
||||
@@ -72,17 +72,38 @@ func (m *Migrator) initTable() error {
|
||||
}
|
||||
|
||||
if !exists {
|
||||
// 创建迁移记录表
|
||||
// 创建迁移记录表(包含执行时间字段)
|
||||
err = m.db.Exec(fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
version VARCHAR(255) PRIMARY KEY,
|
||||
description VARCHAR(255),
|
||||
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
execution_time INT COMMENT '执行耗时(ms)'
|
||||
)
|
||||
`, m.tableName)).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migration table: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 表已存在,检查是否有 execution_time 字段(向后兼容)
|
||||
// 注意:这个检查可能在某些数据库中失败,但不影响功能
|
||||
// 如果字段不存在,记录执行时间时会失败,但不影响迁移执行
|
||||
var hasExecutionTime bool
|
||||
checkSQL := fmt.Sprintf(`
|
||||
SELECT COUNT(*) > 0
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
AND column_name = 'execution_time'
|
||||
`, m.tableName)
|
||||
err = m.db.Raw(checkSQL).Scan(&hasExecutionTime).Error
|
||||
if err == nil && !hasExecutionTime {
|
||||
// 尝试添加字段(如果失败不影响功能)
|
||||
_ = m.db.Exec(fmt.Sprintf(`
|
||||
ALTER TABLE %s
|
||||
ADD COLUMN execution_time INT COMMENT '执行耗时(ms)'
|
||||
`, m.tableName))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -109,17 +130,34 @@ func (m *Migrator) getAppliedMigrations() (map[string]bool, error) {
|
||||
}
|
||||
|
||||
// recordMigration 记录迁移
|
||||
func (m *Migrator) recordMigration(version, description string, isUp bool) error {
|
||||
func (m *Migrator) recordMigration(version, description string, isUp bool, executionTime ...int) error {
|
||||
if err := m.initTable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isUp {
|
||||
// 记录迁移
|
||||
err := m.db.Exec(fmt.Sprintf(`
|
||||
INSERT INTO %s (version, description, applied_at)
|
||||
VALUES (?, ?, ?)
|
||||
`, m.tableName), version, description, time.Now()).Error
|
||||
// 记录迁移(包含执行时间,如果提供了)
|
||||
var err error
|
||||
if len(executionTime) > 0 && executionTime[0] > 0 {
|
||||
// 尝试插入执行时间(如果字段存在)
|
||||
err = m.db.Exec(fmt.Sprintf(`
|
||||
INSERT INTO %s (version, description, applied_at, execution_time)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, m.tableName), version, description, time.Now(), executionTime[0]).Error
|
||||
if err != nil {
|
||||
// 如果失败(可能是字段不存在),尝试不包含执行时间
|
||||
err = m.db.Exec(fmt.Sprintf(`
|
||||
INSERT INTO %s (version, description, applied_at)
|
||||
VALUES (?, ?, ?)
|
||||
`, m.tableName), version, description, time.Now()).Error
|
||||
}
|
||||
} else {
|
||||
// 不包含执行时间
|
||||
err = m.db.Exec(fmt.Sprintf(`
|
||||
INSERT INTO %s (version, description, applied_at)
|
||||
VALUES (?, ?, ?)
|
||||
`, m.tableName), version, description, time.Now()).Error
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record migration: %w", err)
|
||||
}
|
||||
@@ -160,6 +198,9 @@ func (m *Migrator) Up() error {
|
||||
return fmt.Errorf("migration %s has no Up function", migration.Version)
|
||||
}
|
||||
|
||||
// 记录开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 开始事务
|
||||
tx := m.db.Begin()
|
||||
if tx.Error != nil {
|
||||
@@ -172,8 +213,11 @@ func (m *Migrator) Up() error {
|
||||
return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
// 记录迁移
|
||||
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true); err != nil {
|
||||
// 计算执行时间(毫秒)
|
||||
executionTime := int(time.Since(startTime).Milliseconds())
|
||||
|
||||
// 记录迁移(包含执行时间)
|
||||
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true, executionTime); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
@@ -183,7 +227,7 @@ func (m *Migrator) Up() error {
|
||||
return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
fmt.Printf("Applied migration: %s - %s\n", migration.Version, migration.Description)
|
||||
fmt.Printf("Applied migration: %s - %s (耗时: %dms)\n", migration.Version, migration.Description, executionTime)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -246,12 +290,29 @@ func (m *Migrator) Down() error {
|
||||
}
|
||||
|
||||
// recordMigrationWithDB 使用指定的数据库连接记录迁移
|
||||
func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool) error {
|
||||
func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool, executionTime ...int) error {
|
||||
if isUp {
|
||||
err := db.Exec(fmt.Sprintf(`
|
||||
INSERT INTO %s (version, description, applied_at)
|
||||
VALUES (?, ?, ?)
|
||||
`, m.tableName), version, description, time.Now()).Error
|
||||
var err error
|
||||
if len(executionTime) > 0 && executionTime[0] > 0 {
|
||||
// 尝试插入执行时间(如果字段存在)
|
||||
err = db.Exec(fmt.Sprintf(`
|
||||
INSERT INTO %s (version, description, applied_at, execution_time)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, m.tableName), version, description, time.Now(), executionTime[0]).Error
|
||||
if err != nil {
|
||||
// 如果失败(可能是字段不存在),尝试不包含执行时间
|
||||
err = db.Exec(fmt.Sprintf(`
|
||||
INSERT INTO %s (version, description, applied_at)
|
||||
VALUES (?, ?, ?)
|
||||
`, m.tableName), version, description, time.Now()).Error
|
||||
}
|
||||
} else {
|
||||
// 不包含执行时间
|
||||
err = db.Exec(fmt.Sprintf(`
|
||||
INSERT INTO %s (version, description, applied_at)
|
||||
VALUES (?, ?, ?)
|
||||
`, m.tableName), version, description, time.Now()).Error
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record migration: %w", err)
|
||||
}
|
||||
@@ -300,10 +361,96 @@ type MigrationStatus struct {
|
||||
Applied bool
|
||||
}
|
||||
|
||||
// splitSQL 分割SQL语句,处理多行SQL、注释等
|
||||
// 支持单行注释(--)、多行注释(/* */)、按分号分割语句
|
||||
func splitSQL(content string) []string {
|
||||
var statements []string
|
||||
var current strings.Builder
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
inMultiLineComment := false
|
||||
|
||||
for _, line := range lines {
|
||||
trimmedLine := strings.TrimSpace(line)
|
||||
|
||||
// 跳过空行
|
||||
if trimmedLine == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理多行注释
|
||||
if strings.HasPrefix(trimmedLine, "/*") {
|
||||
inMultiLineComment = true
|
||||
}
|
||||
if strings.HasSuffix(trimmedLine, "*/") {
|
||||
inMultiLineComment = false
|
||||
continue
|
||||
}
|
||||
if inMultiLineComment {
|
||||
continue
|
||||
}
|
||||
|
||||
// 跳过单行注释
|
||||
if strings.HasPrefix(trimmedLine, "--") {
|
||||
continue
|
||||
}
|
||||
|
||||
// 添加到当前语句
|
||||
current.WriteString(line)
|
||||
current.WriteString("\n")
|
||||
|
||||
// 检查是否是完整语句(以分号结尾)
|
||||
if strings.HasSuffix(trimmedLine, ";") {
|
||||
stmt := strings.TrimSpace(current.String())
|
||||
if stmt != "" && !strings.HasPrefix(stmt, "--") {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
current.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// 添加最后一个语句(如果没有分号结尾)
|
||||
if current.Len() > 0 {
|
||||
stmt := strings.TrimSpace(current.String())
|
||||
if stmt != "" && !strings.HasPrefix(stmt, "--") {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
// parseMigrationFileName 解析迁移文件名,支持多种格式
|
||||
// 格式1: 数字前缀 - 01_init_schema.sql
|
||||
// 格式2: 时间戳 - 20240101000001_create_users.sql
|
||||
// 格式3: 带.up后缀 - 20240101000001_create_users.up.sql
|
||||
// 返回: (version, description, error)
|
||||
func parseMigrationFileName(baseName string) (string, string, error) {
|
||||
// 移除扩展名
|
||||
nameWithoutExt := strings.TrimSuffix(baseName, filepath.Ext(baseName))
|
||||
// 移除 .up 后缀(如果存在)
|
||||
nameWithoutExt = strings.TrimSuffix(nameWithoutExt, ".up")
|
||||
|
||||
// 解析版本号和描述
|
||||
parts := strings.SplitN(nameWithoutExt, "_", 2)
|
||||
if len(parts) < 2 {
|
||||
// 如果只有一个部分,尝试作为版本号(向后兼容)
|
||||
return nameWithoutExt, baseName, nil
|
||||
}
|
||||
|
||||
version := parts[0]
|
||||
description := strings.Join(parts[1:], "_")
|
||||
|
||||
return version, description, nil
|
||||
}
|
||||
|
||||
// LoadMigrationsFromFiles 从文件系统加载迁移文件
|
||||
// dir: 迁移文件目录
|
||||
// pattern: 文件命名模式,例如 "*.sql" 或 "*.up.sql"
|
||||
// 文件命名格式: {version}_{description}.sql 或 {version}_{description}.up.sql
|
||||
// 文件命名格式支持:
|
||||
// - 数字前缀: 01_init_schema.sql
|
||||
// - 时间戳: 20240101000001_create_users.sql
|
||||
// - 带.up后缀: 20240101000001_create_users.up.sql
|
||||
func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
|
||||
files, err := filepath.Glob(filepath.Join(dir, pattern))
|
||||
if err != nil {
|
||||
@@ -313,19 +460,17 @@ func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
|
||||
migrations := make([]Migration, 0)
|
||||
for _, file := range files {
|
||||
baseName := filepath.Base(file)
|
||||
// 移除扩展名
|
||||
nameWithoutExt := strings.TrimSuffix(baseName, filepath.Ext(baseName))
|
||||
// 移除 .up 后缀(如果存在)
|
||||
nameWithoutExt = strings.TrimSuffix(nameWithoutExt, ".up")
|
||||
|
||||
// 解析版本号和描述
|
||||
parts := strings.SplitN(nameWithoutExt, "_", 2)
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("invalid migration file name format: %s (expected: {version}_{description})", baseName)
|
||||
// 跳过 .down.sql 文件(会在处理 .up.sql 或 .sql 时自动加载)
|
||||
if strings.HasSuffix(baseName, ".down.sql") {
|
||||
continue
|
||||
}
|
||||
|
||||
version := parts[0]
|
||||
description := strings.Join(parts[1:], "_")
|
||||
// 解析版本号和描述
|
||||
version, description, err := parseMigrationFileName(baseName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid migration file name format: %s: %w", baseName, err)
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
content, err := os.ReadFile(file)
|
||||
@@ -343,26 +488,81 @@ func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
|
||||
downSQL = string(downContent)
|
||||
}
|
||||
|
||||
// 创建迁移,使用 SQL 分割功能
|
||||
migration := Migration{
|
||||
Version: version,
|
||||
Description: description,
|
||||
Up: func(db *gorm.DB) error {
|
||||
return db.Exec(sqlContent).Error
|
||||
// 分割 SQL 语句
|
||||
statements := splitSQL(sqlContent)
|
||||
if len(statements) == 0 {
|
||||
return nil // 空文件,跳过
|
||||
}
|
||||
|
||||
// 执行每个 SQL 语句
|
||||
// 注意:某些 DDL 语句(如 CREATE TABLE)在某些数据库中会隐式提交事务
|
||||
// 因此这里不使用事务,而是逐个执行
|
||||
for i, stmt := range statements {
|
||||
stmt = strings.TrimSpace(stmt)
|
||||
if stmt == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := db.Exec(stmt).Error; err != nil {
|
||||
// 如果是表已存在的错误,记录警告但继续执行(向后兼容)
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "already exists") ||
|
||||
strings.Contains(errStr, "Duplicate") ||
|
||||
strings.Contains(errStr, "duplicate") {
|
||||
fmt.Printf("Warning: SQL statement %d in migration %s: %v\n", i+1, version, err)
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("failed to execute SQL statement %d in migration %s: %w\nSQL: %s", i+1, version, err, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
if downSQL != "" {
|
||||
migration.Down = func(db *gorm.DB) error {
|
||||
return db.Exec(downSQL).Error
|
||||
// 分割 SQL 语句
|
||||
statements := splitSQL(downSQL)
|
||||
if len(statements) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行每个 SQL 语句
|
||||
for i, stmt := range statements {
|
||||
stmt = strings.TrimSpace(stmt)
|
||||
if stmt == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := db.Exec(stmt).Error; err != nil {
|
||||
return fmt.Errorf("failed to execute SQL statement %d in rollback %s: %w\nSQL: %s", i+1, version, err, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
migrations = append(migrations, migration)
|
||||
}
|
||||
|
||||
// 按版本号排序
|
||||
// 按版本号排序(支持数字和时间戳混合排序)
|
||||
sort.Slice(migrations, func(i, j int) bool {
|
||||
return migrations[i].Version < migrations[j].Version
|
||||
vi, vj := migrations[i].Version, migrations[j].Version
|
||||
// 尝试按数字排序(如果是数字前缀)
|
||||
if viNum, err1 := strconv.Atoi(vi); err1 == nil {
|
||||
if vjNum, err2 := strconv.Atoi(vj); err2 == nil {
|
||||
return viNum < vjNum
|
||||
}
|
||||
}
|
||||
// 否则按字符串排序
|
||||
return vi < vj
|
||||
})
|
||||
|
||||
return migrations, nil
|
||||
|
||||
Reference in New Issue
Block a user