package migration import ( "bufio" "fmt" "os" "path/filepath" "sort" "strconv" "strings" "time" "gorm.io/gorm" ) // Migration 表示一个数据库迁移 type Migration struct { Version string Description string Up func(*gorm.DB) error Down func(*gorm.DB) error } // Migrator 数据库迁移器 type Migrator struct { db *gorm.DB migrations []Migration tableName string } // NewMigrator 创建新的迁移器 // db: GORM数据库连接 // tableName: 存储迁移记录的表名,默认为 "schema_migrations" func NewMigrator(db *gorm.DB, tableName ...string) *Migrator { table := "schema_migrations" if len(tableName) > 0 && tableName[0] != "" { table = tableName[0] } return &Migrator{ db: db, migrations: make([]Migration, 0), tableName: table, } } // AddMigration 添加迁移 func (m *Migrator) AddMigration(migration Migration) { m.migrations = append(m.migrations, migration) } // AddMigrations 批量添加迁移 func (m *Migrator) AddMigrations(migrations ...Migration) { m.migrations = append(m.migrations, migrations...) } // initTable 初始化迁移记录表 func (m *Migrator) initTable() error { // 检查表是否存在 var exists bool err := m.db.Raw(fmt.Sprintf(` SELECT EXISTS ( SELECT FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = '%s' ) `, m.tableName)).Scan(&exists).Error if err != nil { // 如果查询失败,可能是SQLite或其他数据库,尝试直接创建 exists = false } if !exists { // 创建迁移记录表(包含执行时间字段) err = m.db.Exec(fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( version VARCHAR(255) PRIMARY KEY, description VARCHAR(255), applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, execution_time INT COMMENT '执行耗时(ms)' ) `, m.tableName)).Error if err != nil { return fmt.Errorf("failed to create migration table: %w", err) } } else { // 表已存在,检查是否有 execution_time 字段(向后兼容) // 注意:这个检查可能在某些数据库中失败,但不影响功能 // 如果字段不存在,记录执行时间时会失败,但不影响迁移执行 var hasExecutionTime bool checkSQL := fmt.Sprintf(` SELECT COUNT(*) > 0 FROM information_schema.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = '%s' AND column_name = 'execution_time' `, m.tableName) err = m.db.Raw(checkSQL).Scan(&hasExecutionTime).Error if err == nil && !hasExecutionTime { // 尝试添加字段(如果失败不影响功能) _ = m.db.Exec(fmt.Sprintf(` ALTER TABLE %s ADD COLUMN execution_time INT COMMENT '执行耗时(ms)' `, m.tableName)) } } return nil } // getAppliedMigrations 获取已应用的迁移版本 func (m *Migrator) getAppliedMigrations() (map[string]bool, error) { if err := m.initTable(); err != nil { return nil, err } var versions []string err := m.db.Table(m.tableName).Select("version").Pluck("version", &versions).Error if err != nil { return nil, fmt.Errorf("failed to get applied migrations: %w", err) } applied := make(map[string]bool) for _, v := range versions { applied[v] = true } return applied, nil } // recordMigration 记录迁移 func (m *Migrator) recordMigration(version, description string, isUp bool, executionTime ...int) error { if err := m.initTable(); err != nil { return err } if isUp { // 记录迁移(包含执行时间,如果提供了) var err error if len(executionTime) > 0 && executionTime[0] > 0 { // 尝试插入执行时间(如果字段存在) err = m.db.Exec(fmt.Sprintf(` INSERT INTO %s (version, description, applied_at, execution_time) VALUES (?, ?, ?, ?) `, m.tableName), version, description, time.Now(), executionTime[0]).Error if err != nil { // 如果失败(可能是字段不存在),尝试不包含执行时间 err = m.db.Exec(fmt.Sprintf(` INSERT INTO %s (version, description, applied_at) VALUES (?, ?, ?) `, m.tableName), version, description, time.Now()).Error } } else { // 不包含执行时间 err = m.db.Exec(fmt.Sprintf(` INSERT INTO %s (version, description, applied_at) VALUES (?, ?, ?) `, m.tableName), version, description, time.Now()).Error } if err != nil { return fmt.Errorf("failed to record migration: %w", err) } } else { // 删除迁移记录 err := m.db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error if err != nil { return fmt.Errorf("failed to remove migration record: %w", err) } } return nil } // Up 执行所有未应用的迁移 func (m *Migrator) Up() error { if err := m.initTable(); err != nil { return err } applied, err := m.getAppliedMigrations() if err != nil { return err } // 排序迁移 sort.Slice(m.migrations, func(i, j int) bool { return m.migrations[i].Version < m.migrations[j].Version }) // 执行未应用的迁移 for _, migration := range m.migrations { if applied[migration.Version] { continue } if migration.Up == nil { return fmt.Errorf("migration %s has no Up function", migration.Version) } // 记录开始时间 startTime := time.Now() // 开始事务 tx := m.db.Begin() if tx.Error != nil { return fmt.Errorf("failed to begin transaction: %w", tx.Error) } // 执行迁移 if err := migration.Up(tx); err != nil { tx.Rollback() return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err) } // 计算执行时间(毫秒) executionTime := int(time.Since(startTime).Milliseconds()) // 记录迁移(包含执行时间) if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true, executionTime); err != nil { tx.Rollback() return err } // 提交事务 if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err) } fmt.Printf("Applied migration: %s - %s (耗时: %dms)\n", migration.Version, migration.Description, executionTime) } return nil } // Down 回滚最后一个迁移 func (m *Migrator) Down() error { if err := m.initTable(); err != nil { return err } applied, err := m.getAppliedMigrations() if err != nil { return err } // 排序迁移(倒序) sort.Slice(m.migrations, func(i, j int) bool { return m.migrations[i].Version > m.migrations[j].Version }) // 找到最后一个已应用的迁移并回滚 for _, migration := range m.migrations { if !applied[migration.Version] { continue } if migration.Down == nil { return fmt.Errorf("migration %s has no Down function", migration.Version) } // 开始事务 tx := m.db.Begin() if tx.Error != nil { return fmt.Errorf("failed to begin transaction: %w", tx.Error) } // 执行回滚 if err := migration.Down(tx); err != nil { tx.Rollback() return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err) } // 删除迁移记录 if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, false); err != nil { tx.Rollback() return err } // 提交事务 if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err) } fmt.Printf("Rolled back migration: %s - %s\n", migration.Version, migration.Description) return nil } return fmt.Errorf("no migrations to rollback") } // recordMigrationWithDB 使用指定的数据库连接记录迁移 func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool, executionTime ...int) error { if isUp { var err error if len(executionTime) > 0 && executionTime[0] > 0 { // 尝试插入执行时间(如果字段存在) err = db.Exec(fmt.Sprintf(` INSERT INTO %s (version, description, applied_at, execution_time) VALUES (?, ?, ?, ?) `, m.tableName), version, description, time.Now(), executionTime[0]).Error if err != nil { // 如果失败(可能是字段不存在),尝试不包含执行时间 err = db.Exec(fmt.Sprintf(` INSERT INTO %s (version, description, applied_at) VALUES (?, ?, ?) `, m.tableName), version, description, time.Now()).Error } } else { // 不包含执行时间 err = db.Exec(fmt.Sprintf(` INSERT INTO %s (version, description, applied_at) VALUES (?, ?, ?) `, m.tableName), version, description, time.Now()).Error } if err != nil { return fmt.Errorf("failed to record migration: %w", err) } } else { err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error if err != nil { return fmt.Errorf("failed to remove migration record: %w", err) } } return nil } // Status 查看迁移状态 func (m *Migrator) Status() ([]MigrationStatus, error) { if err := m.initTable(); err != nil { return nil, err } applied, err := m.getAppliedMigrations() if err != nil { return nil, err } // 排序迁移 sort.Slice(m.migrations, func(i, j int) bool { return m.migrations[i].Version < m.migrations[j].Version }) status := make([]MigrationStatus, 0, len(m.migrations)) for _, migration := range m.migrations { status = append(status, MigrationStatus{ Version: migration.Version, Description: migration.Description, Applied: applied[migration.Version], }) } return status, nil } // MigrationStatus 迁移状态 type MigrationStatus struct { Version string Description string Applied bool } // splitSQL 分割SQL语句,处理多行SQL、注释等 // 支持单行注释(--)、多行注释(/* */)、按分号分割语句 func splitSQL(content string) []string { var statements []string var current strings.Builder lines := strings.Split(content, "\n") inMultiLineComment := false for _, line := range lines { trimmedLine := strings.TrimSpace(line) // 跳过空行 if trimmedLine == "" { continue } // 处理多行注释 if strings.HasPrefix(trimmedLine, "/*") { inMultiLineComment = true } if strings.HasSuffix(trimmedLine, "*/") { inMultiLineComment = false continue } if inMultiLineComment { continue } // 跳过单行注释 if strings.HasPrefix(trimmedLine, "--") { continue } // 添加到当前语句 current.WriteString(line) current.WriteString("\n") // 检查是否是完整语句(以分号结尾) if strings.HasSuffix(trimmedLine, ";") { stmt := strings.TrimSpace(current.String()) if stmt != "" && !strings.HasPrefix(stmt, "--") { statements = append(statements, stmt) } current.Reset() } } // 添加最后一个语句(如果没有分号结尾) if current.Len() > 0 { stmt := strings.TrimSpace(current.String()) if stmt != "" && !strings.HasPrefix(stmt, "--") { statements = append(statements, stmt) } } return statements } // parseMigrationFileName 解析迁移文件名,支持多种格式 // 格式1: 数字前缀 - 01_init_schema.sql // 格式2: 时间戳 - 20240101000001_create_users.sql // 格式3: 带.up后缀 - 20240101000001_create_users.up.sql // 返回: (version, description, error) func parseMigrationFileName(baseName string) (string, string, error) { // 移除扩展名 nameWithoutExt := strings.TrimSuffix(baseName, filepath.Ext(baseName)) // 移除 .up 后缀(如果存在) nameWithoutExt = strings.TrimSuffix(nameWithoutExt, ".up") // 解析版本号和描述 parts := strings.SplitN(nameWithoutExt, "_", 2) if len(parts) < 2 { // 如果只有一个部分,尝试作为版本号(向后兼容) return nameWithoutExt, baseName, nil } version := parts[0] description := strings.Join(parts[1:], "_") return version, description, nil } // LoadMigrationsFromFiles 从文件系统加载迁移文件 // dir: 迁移文件目录 // pattern: 文件命名模式,例如 "*.sql" 或 "*.up.sql" // 文件命名格式支持: // - 数字前缀: 01_init_schema.sql // - 时间戳: 20240101000001_create_users.sql // - 带.up后缀: 20240101000001_create_users.up.sql func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) { files, err := filepath.Glob(filepath.Join(dir, pattern)) if err != nil { return nil, fmt.Errorf("failed to glob migration files: %w", err) } migrations := make([]Migration, 0) for _, file := range files { baseName := filepath.Base(file) // 跳过 .down.sql 文件(会在处理 .up.sql 或 .sql 时自动加载) if strings.HasSuffix(baseName, ".down.sql") { continue } // 解析版本号和描述 version, description, err := parseMigrationFileName(baseName) if err != nil { return nil, fmt.Errorf("invalid migration file name format: %s: %w", baseName, err) } // 读取文件内容 content, err := os.ReadFile(file) if err != nil { return nil, fmt.Errorf("failed to read migration file %s: %w", file, err) } sqlContent := string(content) // 查找对应的 down 文件 downFile := strings.Replace(file, ".up.sql", ".down.sql", 1) downFile = strings.Replace(downFile, ".sql", ".down.sql", 1) var downSQL string if downContent, err := os.ReadFile(downFile); err == nil { downSQL = string(downContent) } // 创建迁移,使用 SQL 分割功能 migration := Migration{ Version: version, Description: description, Up: func(db *gorm.DB) error { // 分割 SQL 语句 statements := splitSQL(sqlContent) if len(statements) == 0 { return nil // 空文件,跳过 } // 执行每个 SQL 语句 // 注意:某些 DDL 语句(如 CREATE TABLE)在某些数据库中会隐式提交事务 // 因此这里不使用事务,而是逐个执行 for i, stmt := range statements { stmt = strings.TrimSpace(stmt) if stmt == "" { continue } if err := db.Exec(stmt).Error; err != nil { // 如果是表已存在的错误,记录警告但继续执行(向后兼容) errStr := err.Error() if strings.Contains(errStr, "already exists") || strings.Contains(errStr, "Duplicate") || strings.Contains(errStr, "duplicate") { fmt.Printf("Warning: SQL statement %d in migration %s: %v\n", i+1, version, err) continue } return fmt.Errorf("failed to execute SQL statement %d in migration %s: %w\nSQL: %s", i+1, version, err, stmt) } } return nil }, } if downSQL != "" { migration.Down = func(db *gorm.DB) error { // 分割 SQL 语句 statements := splitSQL(downSQL) if len(statements) == 0 { return nil } // 执行每个 SQL 语句 for i, stmt := range statements { stmt = strings.TrimSpace(stmt) if stmt == "" { continue } if err := db.Exec(stmt).Error; err != nil { return fmt.Errorf("failed to execute SQL statement %d in rollback %s: %w\nSQL: %s", i+1, version, err, stmt) } } return nil } } migrations = append(migrations, migration) } // 按版本号排序(支持数字和时间戳混合排序) sort.Slice(migrations, func(i, j int) bool { vi, vj := migrations[i].Version, migrations[j].Version // 尝试按数字排序(如果是数字前缀) if viNum, err1 := strconv.Atoi(vi); err1 == nil { if vjNum, err2 := strconv.Atoi(vj); err2 == nil { return viNum < vjNum } } // 否则按字符串排序 return vi < vj }) return migrations, nil } // GenerateVersion 生成迁移版本号(基于时间戳) func GenerateVersion() string { return strconv.FormatInt(time.Now().Unix(), 10) } // Reset 重置所有迁移(清空迁移记录表) // confirm: 确认标志,必须为true才能执行重置 // 注意:此操作会清空所有迁移记录,但不会回滚已执行的迁移操作 // 如果需要回滚迁移,请先使用Down方法逐个回滚 func (m *Migrator) Reset(confirm bool) error { if !confirm { return fmt.Errorf("reset operation requires explicit confirmation (confirm must be true)") } if err := m.initTable(); err != nil { return err } // 获取已应用的迁移数量 applied, err := m.getAppliedMigrations() if err != nil { return err } count := len(applied) if count == 0 { fmt.Println("No migrations to reset") return nil } // 清空迁移记录表 err = m.db.Exec(fmt.Sprintf("DELETE FROM %s", m.tableName)).Error if err != nil { return fmt.Errorf("failed to reset migrations: %w", err) } fmt.Printf("Reset completed: %d migration record(s) cleared\n", count) fmt.Println("WARNING: This only clears migration records, not the actual database changes.") fmt.Println("If you need to rollback database changes, use Down() method before reset.") return nil } // ResetWithConfirm 交互式重置所有迁移(带确认提示) // 会提示用户数据会被清空,需要输入确认才能执行 func (m *Migrator) ResetWithConfirm() error { if err := m.initTable(); err != nil { return err } // 获取已应用的迁移数量 applied, err := m.getAppliedMigrations() if err != nil { return err } count := len(applied) if count == 0 { fmt.Println("No migrations to reset") return nil } // 显示警告信息 fmt.Println("=" + strings.Repeat("=", 60) + "=") fmt.Println("WARNING: This operation will clear all migration records!") fmt.Println("=" + strings.Repeat("=", 60) + "=") fmt.Printf("This will clear %d migration record(s) from the database.\n", count) fmt.Println() fmt.Println("IMPORTANT NOTES:") fmt.Println(" 1. This operation only clears migration records, NOT the actual database changes.") fmt.Println(" 2. If you need to rollback database changes, use Down() method before reset.") fmt.Println(" 3. After reset, you may need to re-run migrations if the database structure changed.") fmt.Println() fmt.Print("Type 'RESET' (all caps) to confirm, or anything else to cancel: ") // 读取用户输入 reader := bufio.NewReader(os.Stdin) input, err := reader.ReadString('\n') if err != nil { return fmt.Errorf("failed to read confirmation: %w", err) } // 去除换行符和空格 input = strings.TrimSpace(input) // 检查确认 if input != "RESET" { fmt.Println("Reset operation cancelled.") return nil } // 执行重置 return m.Reset(true) } // ResetAll 重置所有迁移并回滚所有已应用的迁移 // confirm: 确认标志,必须为true才能执行 // 注意:此操作会回滚所有已应用的迁移,然后清空迁移记录 func (m *Migrator) ResetAll(confirm bool) error { if !confirm { return fmt.Errorf("reset all operation requires explicit confirmation (confirm must be true)") } if err := m.initTable(); err != nil { return err } // 获取已应用的迁移 applied, err := m.getAppliedMigrations() if err != nil { return err } count := len(applied) if count == 0 { fmt.Println("No migrations to reset") return nil } // 回滚所有已应用的迁移 fmt.Printf("Rolling back %d migration(s)...\n", count) // 排序迁移(倒序) sort.Slice(m.migrations, func(i, j int) bool { return m.migrations[i].Version > m.migrations[j].Version }) // 回滚所有已应用的迁移 for _, migration := range m.migrations { if !applied[migration.Version] { continue } if migration.Down == nil { fmt.Printf("Warning: Migration %s has no Down function, skipping rollback\n", migration.Version) continue } // 开始事务 tx := m.db.Begin() if tx.Error != nil { return fmt.Errorf("failed to begin transaction: %w", tx.Error) } // 执行回滚 if err := migration.Down(tx); err != nil { tx.Rollback() return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err) } // 删除迁移记录 if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, false); err != nil { tx.Rollback() return err } // 提交事务 if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err) } fmt.Printf("Rolled back migration: %s - %s\n", migration.Version, migration.Description) } fmt.Printf("Reset all completed: %d migration(s) rolled back and records cleared\n", count) return nil } // ResetAllWithConfirm 交互式重置所有迁移并回滚(带确认提示) // 会提示用户数据会被清空,需要输入确认才能执行 func (m *Migrator) ResetAllWithConfirm() error { if err := m.initTable(); err != nil { return err } // 获取已应用的迁移数量 applied, err := m.getAppliedMigrations() if err != nil { return err } count := len(applied) if count == 0 { fmt.Println("No migrations to reset") return nil } // 显示警告信息 fmt.Println("=" + strings.Repeat("=", 60) + "=") fmt.Println("WARNING: This operation will rollback ALL migrations and clear records!") fmt.Println("=" + strings.Repeat("=", 60) + "=") fmt.Printf("This will rollback %d migration(s) and clear all migration records.\n", count) fmt.Println() fmt.Println("IMPORTANT NOTES:") fmt.Println(" 1. This operation will EXECUTE Down() functions for all applied migrations.") fmt.Println(" 2. All database changes made by migrations will be REVERTED.") fmt.Println(" 3. This operation CANNOT be undone!") fmt.Println(" 4. Make sure you have a database backup before proceeding.") fmt.Println() fmt.Print("Type 'RESET ALL' (all caps) to confirm, or anything else to cancel: ") // 读取用户输入 reader := bufio.NewReader(os.Stdin) input, err := reader.ReadString('\n') if err != nil { return fmt.Errorf("failed to read confirmation: %w", err) } // 去除换行符和空格 input = strings.TrimSpace(input) // 检查确认 if input != "RESET ALL" { fmt.Println("Reset all operation cancelled.") return nil } // 执行重置 return m.ResetAll(true) }