package migration import ( "bufio" "fmt" "os" "path/filepath" "sort" "strconv" "strings" "time" "gorm.io/gorm" ) // Migration 表示一个数据库迁移 type Migration struct { Version string Description string Up func(*gorm.DB) error Down func(*gorm.DB) error } // Migrator 数据库迁移器 type Migrator struct { db *gorm.DB migrations []Migration tableName string } // NewMigrator 创建新的迁移器 // db: GORM数据库连接 // tableName: 存储迁移记录的表名,默认为 "schema_migrations" func NewMigrator(db *gorm.DB, tableName ...string) *Migrator { table := "schema_migrations" if len(tableName) > 0 && tableName[0] != "" { table = tableName[0] } return &Migrator{ db: db, migrations: make([]Migration, 0), tableName: table, } } // AddMigration 添加迁移 func (m *Migrator) AddMigration(migration Migration) { m.migrations = append(m.migrations, migration) } // AddMigrations 批量添加迁移 func (m *Migrator) AddMigrations(migrations ...Migration) { m.migrations = append(m.migrations, migrations...) } // initTable 初始化迁移记录表 func (m *Migrator) initTable() error { // 检查表是否存在 var exists bool err := m.db.Raw(fmt.Sprintf(` SELECT EXISTS ( SELECT FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = '%s' ) `, m.tableName)).Scan(&exists).Error if err != nil { // 如果查询失败,可能是SQLite或其他数据库,尝试直接创建 exists = false } if !exists { // 创建迁移记录表 err = m.db.Exec(fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( version VARCHAR(255) PRIMARY KEY, description VARCHAR(255), applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) `, m.tableName)).Error if err != nil { return fmt.Errorf("failed to create migration table: %w", err) } } return nil } // getAppliedMigrations 获取已应用的迁移版本 func (m *Migrator) getAppliedMigrations() (map[string]bool, error) { if err := m.initTable(); err != nil { return nil, err } var versions []string err := m.db.Table(m.tableName).Select("version").Pluck("version", &versions).Error if err != nil { return nil, fmt.Errorf("failed to get applied migrations: %w", err) } applied := make(map[string]bool) for _, v := range versions { applied[v] = true } return applied, nil } // recordMigration 记录迁移 func (m *Migrator) recordMigration(version, description string, isUp bool) error { if err := m.initTable(); err != nil { return err } if isUp { // 记录迁移 err := m.db.Exec(fmt.Sprintf(` INSERT INTO %s (version, description, applied_at) VALUES (?, ?, ?) `, m.tableName), version, description, time.Now()).Error if err != nil { return fmt.Errorf("failed to record migration: %w", err) } } else { // 删除迁移记录 err := m.db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error if err != nil { return fmt.Errorf("failed to remove migration record: %w", err) } } return nil } // Up 执行所有未应用的迁移 func (m *Migrator) Up() error { if err := m.initTable(); err != nil { return err } applied, err := m.getAppliedMigrations() if err != nil { return err } // 排序迁移 sort.Slice(m.migrations, func(i, j int) bool { return m.migrations[i].Version < m.migrations[j].Version }) // 执行未应用的迁移 for _, migration := range m.migrations { if applied[migration.Version] { continue } if migration.Up == nil { return fmt.Errorf("migration %s has no Up function", migration.Version) } // 开始事务 tx := m.db.Begin() if tx.Error != nil { return fmt.Errorf("failed to begin transaction: %w", tx.Error) } // 执行迁移 if err := migration.Up(tx); err != nil { tx.Rollback() return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err) } // 记录迁移 if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true); err != nil { tx.Rollback() return err } // 提交事务 if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err) } fmt.Printf("Applied migration: %s - %s\n", migration.Version, migration.Description) } return nil } // Down 回滚最后一个迁移 func (m *Migrator) Down() error { if err := m.initTable(); err != nil { return err } applied, err := m.getAppliedMigrations() if err != nil { return err } // 排序迁移(倒序) sort.Slice(m.migrations, func(i, j int) bool { return m.migrations[i].Version > m.migrations[j].Version }) // 找到最后一个已应用的迁移并回滚 for _, migration := range m.migrations { if !applied[migration.Version] { continue } if migration.Down == nil { return fmt.Errorf("migration %s has no Down function", migration.Version) } // 开始事务 tx := m.db.Begin() if tx.Error != nil { return fmt.Errorf("failed to begin transaction: %w", tx.Error) } // 执行回滚 if err := migration.Down(tx); err != nil { tx.Rollback() return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err) } // 删除迁移记录 if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, false); err != nil { tx.Rollback() return err } // 提交事务 if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err) } fmt.Printf("Rolled back migration: %s - %s\n", migration.Version, migration.Description) return nil } return fmt.Errorf("no migrations to rollback") } // recordMigrationWithDB 使用指定的数据库连接记录迁移 func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool) error { if isUp { err := db.Exec(fmt.Sprintf(` INSERT INTO %s (version, description, applied_at) VALUES (?, ?, ?) `, m.tableName), version, description, time.Now()).Error if err != nil { return fmt.Errorf("failed to record migration: %w", err) } } else { err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error if err != nil { return fmt.Errorf("failed to remove migration record: %w", err) } } return nil } // Status 查看迁移状态 func (m *Migrator) Status() ([]MigrationStatus, error) { if err := m.initTable(); err != nil { return nil, err } applied, err := m.getAppliedMigrations() if err != nil { return nil, err } // 排序迁移 sort.Slice(m.migrations, func(i, j int) bool { return m.migrations[i].Version < m.migrations[j].Version }) status := make([]MigrationStatus, 0, len(m.migrations)) for _, migration := range m.migrations { status = append(status, MigrationStatus{ Version: migration.Version, Description: migration.Description, Applied: applied[migration.Version], }) } return status, nil } // MigrationStatus 迁移状态 type MigrationStatus struct { Version string Description string Applied bool } // LoadMigrationsFromFiles 从文件系统加载迁移文件 // dir: 迁移文件目录 // pattern: 文件命名模式,例如 "*.sql" 或 "*.up.sql" // 文件命名格式: {version}_{description}.sql 或 {version}_{description}.up.sql func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) { files, err := filepath.Glob(filepath.Join(dir, pattern)) if err != nil { return nil, fmt.Errorf("failed to glob migration files: %w", err) } migrations := make([]Migration, 0) for _, file := range files { baseName := filepath.Base(file) // 移除扩展名 nameWithoutExt := strings.TrimSuffix(baseName, filepath.Ext(baseName)) // 移除 .up 后缀(如果存在) nameWithoutExt = strings.TrimSuffix(nameWithoutExt, ".up") // 解析版本号和描述 parts := strings.SplitN(nameWithoutExt, "_", 2) if len(parts) < 2 { return nil, fmt.Errorf("invalid migration file name format: %s (expected: {version}_{description})", baseName) } version := parts[0] description := strings.Join(parts[1:], "_") // 读取文件内容 content, err := os.ReadFile(file) if err != nil { return nil, fmt.Errorf("failed to read migration file %s: %w", file, err) } sqlContent := string(content) // 查找对应的 down 文件 downFile := strings.Replace(file, ".up.sql", ".down.sql", 1) downFile = strings.Replace(downFile, ".sql", ".down.sql", 1) var downSQL string if downContent, err := os.ReadFile(downFile); err == nil { downSQL = string(downContent) } migration := Migration{ Version: version, Description: description, Up: func(db *gorm.DB) error { return db.Exec(sqlContent).Error }, } if downSQL != "" { migration.Down = func(db *gorm.DB) error { return db.Exec(downSQL).Error } } migrations = append(migrations, migration) } // 按版本号排序 sort.Slice(migrations, func(i, j int) bool { return migrations[i].Version < migrations[j].Version }) return migrations, nil } // GenerateVersion 生成迁移版本号(基于时间戳) func GenerateVersion() string { return strconv.FormatInt(time.Now().Unix(), 10) } // Reset 重置所有迁移(清空迁移记录表) // confirm: 确认标志,必须为true才能执行重置 // 注意:此操作会清空所有迁移记录,但不会回滚已执行的迁移操作 // 如果需要回滚迁移,请先使用Down方法逐个回滚 func (m *Migrator) Reset(confirm bool) error { if !confirm { return fmt.Errorf("reset operation requires explicit confirmation (confirm must be true)") } if err := m.initTable(); err != nil { return err } // 获取已应用的迁移数量 applied, err := m.getAppliedMigrations() if err != nil { return err } count := len(applied) if count == 0 { fmt.Println("No migrations to reset") return nil } // 清空迁移记录表 err = m.db.Exec(fmt.Sprintf("DELETE FROM %s", m.tableName)).Error if err != nil { return fmt.Errorf("failed to reset migrations: %w", err) } fmt.Printf("Reset completed: %d migration record(s) cleared\n", count) fmt.Println("WARNING: This only clears migration records, not the actual database changes.") fmt.Println("If you need to rollback database changes, use Down() method before reset.") return nil } // ResetWithConfirm 交互式重置所有迁移(带确认提示) // 会提示用户数据会被清空,需要输入确认才能执行 func (m *Migrator) ResetWithConfirm() error { if err := m.initTable(); err != nil { return err } // 获取已应用的迁移数量 applied, err := m.getAppliedMigrations() if err != nil { return err } count := len(applied) if count == 0 { fmt.Println("No migrations to reset") return nil } // 显示警告信息 fmt.Println("=" + strings.Repeat("=", 60) + "=") fmt.Println("WARNING: This operation will clear all migration records!") fmt.Println("=" + strings.Repeat("=", 60) + "=") fmt.Printf("This will clear %d migration record(s) from the database.\n", count) fmt.Println() fmt.Println("IMPORTANT NOTES:") fmt.Println(" 1. This operation only clears migration records, NOT the actual database changes.") fmt.Println(" 2. If you need to rollback database changes, use Down() method before reset.") fmt.Println(" 3. After reset, you may need to re-run migrations if the database structure changed.") fmt.Println() fmt.Print("Type 'RESET' (all caps) to confirm, or anything else to cancel: ") // 读取用户输入 reader := bufio.NewReader(os.Stdin) input, err := reader.ReadString('\n') if err != nil { return fmt.Errorf("failed to read confirmation: %w", err) } // 去除换行符和空格 input = strings.TrimSpace(input) // 检查确认 if input != "RESET" { fmt.Println("Reset operation cancelled.") return nil } // 执行重置 return m.Reset(true) } // ResetAll 重置所有迁移并回滚所有已应用的迁移 // confirm: 确认标志,必须为true才能执行 // 注意:此操作会回滚所有已应用的迁移,然后清空迁移记录 func (m *Migrator) ResetAll(confirm bool) error { if !confirm { return fmt.Errorf("reset all operation requires explicit confirmation (confirm must be true)") } if err := m.initTable(); err != nil { return err } // 获取已应用的迁移 applied, err := m.getAppliedMigrations() if err != nil { return err } count := len(applied) if count == 0 { fmt.Println("No migrations to reset") return nil } // 回滚所有已应用的迁移 fmt.Printf("Rolling back %d migration(s)...\n", count) // 排序迁移(倒序) sort.Slice(m.migrations, func(i, j int) bool { return m.migrations[i].Version > m.migrations[j].Version }) // 回滚所有已应用的迁移 for _, migration := range m.migrations { if !applied[migration.Version] { continue } if migration.Down == nil { fmt.Printf("Warning: Migration %s has no Down function, skipping rollback\n", migration.Version) continue } // 开始事务 tx := m.db.Begin() if tx.Error != nil { return fmt.Errorf("failed to begin transaction: %w", tx.Error) } // 执行回滚 if err := migration.Down(tx); err != nil { tx.Rollback() return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err) } // 删除迁移记录 if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, false); err != nil { tx.Rollback() return err } // 提交事务 if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err) } fmt.Printf("Rolled back migration: %s - %s\n", migration.Version, migration.Description) } fmt.Printf("Reset all completed: %d migration(s) rolled back and records cleared\n", count) return nil } // ResetAllWithConfirm 交互式重置所有迁移并回滚(带确认提示) // 会提示用户数据会被清空,需要输入确认才能执行 func (m *Migrator) ResetAllWithConfirm() error { if err := m.initTable(); err != nil { return err } // 获取已应用的迁移数量 applied, err := m.getAppliedMigrations() if err != nil { return err } count := len(applied) if count == 0 { fmt.Println("No migrations to reset") return nil } // 显示警告信息 fmt.Println("=" + strings.Repeat("=", 60) + "=") fmt.Println("WARNING: This operation will rollback ALL migrations and clear records!") fmt.Println("=" + strings.Repeat("=", 60) + "=") fmt.Printf("This will rollback %d migration(s) and clear all migration records.\n", count) fmt.Println() fmt.Println("IMPORTANT NOTES:") fmt.Println(" 1. This operation will EXECUTE Down() functions for all applied migrations.") fmt.Println(" 2. All database changes made by migrations will be REVERTED.") fmt.Println(" 3. This operation CANNOT be undone!") fmt.Println(" 4. Make sure you have a database backup before proceeding.") fmt.Println() fmt.Print("Type 'RESET ALL' (all caps) to confirm, or anything else to cancel: ") // 读取用户输入 reader := bufio.NewReader(os.Stdin) input, err := reader.ReadString('\n') if err != nil { return fmt.Errorf("failed to read confirmation: %w", err) } // 去除换行符和空格 input = strings.TrimSpace(input) // 检查确认 if input != "RESET ALL" { fmt.Println("Reset all operation cancelled.") return nil } // 执行重置 return m.ResetAll(true) }