调整迁移的逻辑

This commit is contained in:
2025-12-06 21:38:53 +08:00
parent 6146178111
commit 6547e7bca8
5 changed files with 515 additions and 88 deletions

View File

@@ -72,17 +72,38 @@ func (m *Migrator) initTable() error {
}
if !exists {
// 创建迁移记录表
// 创建迁移记录表(包含执行时间字段)
err = m.db.Exec(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
version VARCHAR(255) PRIMARY KEY,
description VARCHAR(255),
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
execution_time INT COMMENT '执行耗时(ms)'
)
`, m.tableName)).Error
if err != nil {
return fmt.Errorf("failed to create migration table: %w", err)
}
} else {
// 表已存在,检查是否有 execution_time 字段(向后兼容)
// 注意:这个检查可能在某些数据库中失败,但不影响功能
// 如果字段不存在,记录执行时间时会失败,但不影响迁移执行
var hasExecutionTime bool
checkSQL := fmt.Sprintf(`
SELECT COUNT(*) > 0
FROM information_schema.columns
WHERE table_schema = CURRENT_SCHEMA()
AND table_name = '%s'
AND column_name = 'execution_time'
`, m.tableName)
err = m.db.Raw(checkSQL).Scan(&hasExecutionTime).Error
if err == nil && !hasExecutionTime {
// 尝试添加字段(如果失败不影响功能)
_ = m.db.Exec(fmt.Sprintf(`
ALTER TABLE %s
ADD COLUMN execution_time INT COMMENT '执行耗时(ms)'
`, m.tableName))
}
}
return nil
@@ -109,17 +130,34 @@ func (m *Migrator) getAppliedMigrations() (map[string]bool, error) {
}
// recordMigration 记录迁移
func (m *Migrator) recordMigration(version, description string, isUp bool) error {
func (m *Migrator) recordMigration(version, description string, isUp bool, executionTime ...int) error {
if err := m.initTable(); err != nil {
return err
}
if isUp {
// 记录迁移
err := m.db.Exec(fmt.Sprintf(`
INSERT INTO %s (version, description, applied_at)
VALUES (?, ?, ?)
`, m.tableName), version, description, time.Now()).Error
// 记录迁移(包含执行时间,如果提供了)
var err error
if len(executionTime) > 0 && executionTime[0] > 0 {
// 尝试插入执行时间(如果字段存在)
err = m.db.Exec(fmt.Sprintf(`
INSERT INTO %s (version, description, applied_at, execution_time)
VALUES (?, ?, ?, ?)
`, m.tableName), version, description, time.Now(), executionTime[0]).Error
if err != nil {
// 如果失败(可能是字段不存在),尝试不包含执行时间
err = m.db.Exec(fmt.Sprintf(`
INSERT INTO %s (version, description, applied_at)
VALUES (?, ?, ?)
`, m.tableName), version, description, time.Now()).Error
}
} else {
// 不包含执行时间
err = m.db.Exec(fmt.Sprintf(`
INSERT INTO %s (version, description, applied_at)
VALUES (?, ?, ?)
`, m.tableName), version, description, time.Now()).Error
}
if err != nil {
return fmt.Errorf("failed to record migration: %w", err)
}
@@ -160,6 +198,9 @@ func (m *Migrator) Up() error {
return fmt.Errorf("migration %s has no Up function", migration.Version)
}
// 记录开始时间
startTime := time.Now()
// 开始事务
tx := m.db.Begin()
if tx.Error != nil {
@@ -172,8 +213,11 @@ func (m *Migrator) Up() error {
return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
}
// 记录迁移
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true); err != nil {
// 计算执行时间(毫秒)
executionTime := int(time.Since(startTime).Milliseconds())
// 记录迁移(包含执行时间)
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true, executionTime); err != nil {
tx.Rollback()
return err
}
@@ -183,7 +227,7 @@ func (m *Migrator) Up() error {
return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err)
}
fmt.Printf("Applied migration: %s - %s\n", migration.Version, migration.Description)
fmt.Printf("Applied migration: %s - %s (耗时: %dms)\n", migration.Version, migration.Description, executionTime)
}
return nil
@@ -246,12 +290,29 @@ func (m *Migrator) Down() error {
}
// recordMigrationWithDB 使用指定的数据库连接记录迁移
func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool) error {
func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool, executionTime ...int) error {
if isUp {
err := db.Exec(fmt.Sprintf(`
INSERT INTO %s (version, description, applied_at)
VALUES (?, ?, ?)
`, m.tableName), version, description, time.Now()).Error
var err error
if len(executionTime) > 0 && executionTime[0] > 0 {
// 尝试插入执行时间(如果字段存在)
err = db.Exec(fmt.Sprintf(`
INSERT INTO %s (version, description, applied_at, execution_time)
VALUES (?, ?, ?, ?)
`, m.tableName), version, description, time.Now(), executionTime[0]).Error
if err != nil {
// 如果失败(可能是字段不存在),尝试不包含执行时间
err = db.Exec(fmt.Sprintf(`
INSERT INTO %s (version, description, applied_at)
VALUES (?, ?, ?)
`, m.tableName), version, description, time.Now()).Error
}
} else {
// 不包含执行时间
err = db.Exec(fmt.Sprintf(`
INSERT INTO %s (version, description, applied_at)
VALUES (?, ?, ?)
`, m.tableName), version, description, time.Now()).Error
}
if err != nil {
return fmt.Errorf("failed to record migration: %w", err)
}
@@ -300,10 +361,96 @@ type MigrationStatus struct {
Applied bool
}
// splitSQL 分割SQL语句处理多行SQL、注释等
// 支持单行注释(--)、多行注释(/* */)、按分号分割语句
func splitSQL(content string) []string {
var statements []string
var current strings.Builder
lines := strings.Split(content, "\n")
inMultiLineComment := false
for _, line := range lines {
trimmedLine := strings.TrimSpace(line)
// 跳过空行
if trimmedLine == "" {
continue
}
// 处理多行注释
if strings.HasPrefix(trimmedLine, "/*") {
inMultiLineComment = true
}
if strings.HasSuffix(trimmedLine, "*/") {
inMultiLineComment = false
continue
}
if inMultiLineComment {
continue
}
// 跳过单行注释
if strings.HasPrefix(trimmedLine, "--") {
continue
}
// 添加到当前语句
current.WriteString(line)
current.WriteString("\n")
// 检查是否是完整语句(以分号结尾)
if strings.HasSuffix(trimmedLine, ";") {
stmt := strings.TrimSpace(current.String())
if stmt != "" && !strings.HasPrefix(stmt, "--") {
statements = append(statements, stmt)
}
current.Reset()
}
}
// 添加最后一个语句(如果没有分号结尾)
if current.Len() > 0 {
stmt := strings.TrimSpace(current.String())
if stmt != "" && !strings.HasPrefix(stmt, "--") {
statements = append(statements, stmt)
}
}
return statements
}
// parseMigrationFileName 解析迁移文件名,支持多种格式
// 格式1: 数字前缀 - 01_init_schema.sql
// 格式2: 时间戳 - 20240101000001_create_users.sql
// 格式3: 带.up后缀 - 20240101000001_create_users.up.sql
// 返回: (version, description, error)
func parseMigrationFileName(baseName string) (string, string, error) {
// 移除扩展名
nameWithoutExt := strings.TrimSuffix(baseName, filepath.Ext(baseName))
// 移除 .up 后缀(如果存在)
nameWithoutExt = strings.TrimSuffix(nameWithoutExt, ".up")
// 解析版本号和描述
parts := strings.SplitN(nameWithoutExt, "_", 2)
if len(parts) < 2 {
// 如果只有一个部分,尝试作为版本号(向后兼容)
return nameWithoutExt, baseName, nil
}
version := parts[0]
description := strings.Join(parts[1:], "_")
return version, description, nil
}
// LoadMigrationsFromFiles 从文件系统加载迁移文件
// dir: 迁移文件目录
// pattern: 文件命名模式,例如 "*.sql" 或 "*.up.sql"
// 文件命名格式: {version}_{description}.sql 或 {version}_{description}.up.sql
// 文件命名格式支持:
// - 数字前缀: 01_init_schema.sql
// - 时间戳: 20240101000001_create_users.sql
// - 带.up后缀: 20240101000001_create_users.up.sql
func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
files, err := filepath.Glob(filepath.Join(dir, pattern))
if err != nil {
@@ -313,19 +460,17 @@ func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
migrations := make([]Migration, 0)
for _, file := range files {
baseName := filepath.Base(file)
// 移除扩展名
nameWithoutExt := strings.TrimSuffix(baseName, filepath.Ext(baseName))
// 移除 .up 后缀(如果存在)
nameWithoutExt = strings.TrimSuffix(nameWithoutExt, ".up")
// 解析版本号和描述
parts := strings.SplitN(nameWithoutExt, "_", 2)
if len(parts) < 2 {
return nil, fmt.Errorf("invalid migration file name format: %s (expected: {version}_{description})", baseName)
// 跳过 .down.sql 文件(会在处理 .up.sql 或 .sql 时自动加载)
if strings.HasSuffix(baseName, ".down.sql") {
continue
}
version := parts[0]
description := strings.Join(parts[1:], "_")
// 解析版本号和描述
version, description, err := parseMigrationFileName(baseName)
if err != nil {
return nil, fmt.Errorf("invalid migration file name format: %s: %w", baseName, err)
}
// 读取文件内容
content, err := os.ReadFile(file)
@@ -343,26 +488,81 @@ func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
downSQL = string(downContent)
}
// 创建迁移,使用 SQL 分割功能
migration := Migration{
Version: version,
Description: description,
Up: func(db *gorm.DB) error {
return db.Exec(sqlContent).Error
// 分割 SQL 语句
statements := splitSQL(sqlContent)
if len(statements) == 0 {
return nil // 空文件,跳过
}
// 执行每个 SQL 语句
// 注意:某些 DDL 语句(如 CREATE TABLE在某些数据库中会隐式提交事务
// 因此这里不使用事务,而是逐个执行
for i, stmt := range statements {
stmt = strings.TrimSpace(stmt)
if stmt == "" {
continue
}
if err := db.Exec(stmt).Error; err != nil {
// 如果是表已存在的错误,记录警告但继续执行(向后兼容)
errStr := err.Error()
if strings.Contains(errStr, "already exists") ||
strings.Contains(errStr, "Duplicate") ||
strings.Contains(errStr, "duplicate") {
fmt.Printf("Warning: SQL statement %d in migration %s: %v\n", i+1, version, err)
continue
}
return fmt.Errorf("failed to execute SQL statement %d in migration %s: %w\nSQL: %s", i+1, version, err, stmt)
}
}
return nil
},
}
if downSQL != "" {
migration.Down = func(db *gorm.DB) error {
return db.Exec(downSQL).Error
// 分割 SQL 语句
statements := splitSQL(downSQL)
if len(statements) == 0 {
return nil
}
// 执行每个 SQL 语句
for i, stmt := range statements {
stmt = strings.TrimSpace(stmt)
if stmt == "" {
continue
}
if err := db.Exec(stmt).Error; err != nil {
return fmt.Errorf("failed to execute SQL statement %d in rollback %s: %w\nSQL: %s", i+1, version, err, stmt)
}
}
return nil
}
}
migrations = append(migrations, migration)
}
// 按版本号排序
// 按版本号排序(支持数字和时间戳混合排序)
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].Version < migrations[j].Version
vi, vj := migrations[i].Version, migrations[j].Version
// 尝试按数字排序(如果是数字前缀)
if viNum, err1 := strconv.Atoi(vi); err1 == nil {
if vjNum, err2 := strconv.Atoi(vj); err2 == nil {
return viNum < vjNum
}
}
// 否则按字符串排序
return vi < vj
})
return migrations, nil