// Package migration implements a small, version-ordered schema migrator on
// top of GORM. Applied versions are tracked in a bookkeeping table
// ("schema_migrations" by default).
package migration

import (
	"errors"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"time"

	"gorm.io/gorm"
)

// Migration describes one schema change.
//
// Version is the ordering key: migrations are applied in ascending order of
// Version compared as plain strings, so numeric versions must share a fixed
// width (zero-padded, or timestamps as produced by GenerateVersion) for that
// order to match numeric order. Up and Down receive the transaction the
// migration runs in. Up must be non-nil for the migration to be applied; a nil
// Down means the migration cannot be rolled back.
type Migration struct {
	Version     string
	Description string
	Up          func(*gorm.DB) error
	Down        func(*gorm.DB) error
}

// Migrator applies and rolls back registered migrations against a database,
// recording applied versions in a bookkeeping table. It does no locking of
// its own.
type Migrator struct {
	db         *gorm.DB
	migrations []Migration
	tableName  string
}

// NewMigrator returns a Migrator that runs migrations on db and records applied
// versions in the table named by the optional tableName argument, defaulting
// to "schema_migrations" when it is omitted or empty. The table name is
// interpolated into SQL verbatim, so it must come from trusted code.
func NewMigrator(db *gorm.DB, tableName ...string) *Migrator {
	table := "schema_migrations"
	if len(tableName) > 0 && tableName[0] != "" {
		table = tableName[0]
	}
	return &Migrator{
		db:         db,
		migrations: make([]Migration, 0),
		tableName:  table,
	}
}

// AddMigration registers a single migration.
func (m *Migrator) AddMigration(migration Migration) {
	m.migrations = append(m.migrations, migration)
}

// AddMigrations registers several migrations at once.
func (m *Migrator) AddMigrations(migrations ...Migration) {
	m.migrations = append(m.migrations, migrations...)
}

// initTable creates the bookkeeping table if it does not exist yet.
//
// Existence is checked through GORM's dialect-aware Migrator.HasTable instead
// of a hand-written information_schema query (which was PostgreSQL-specific
// and whose failure was silently ignored on other databases). Skipping the
// CREATE when the table is already present avoids issuing DDL on every call,
// which may need privileges the application role does not have.
func (m *Migrator) initTable() error {
	if m.db.Migrator().HasTable(m.tableName) {
		return nil
	}
	err := m.db.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
	version VARCHAR(255) PRIMARY KEY,
	description VARCHAR(255),
	applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`, m.tableName)).Error
	if err != nil {
		return fmt.Errorf("failed to create migration table: %w", err)
	}
	return nil
}

// getAppliedMigrations ensures the bookkeeping table exists and returns the
// set of versions recorded in it.
func (m *Migrator) getAppliedMigrations() (map[string]bool, error) {
	if err := m.initTable(); err != nil {
		return nil, err
	}
	var versions []string
	if err := m.db.Table(m.tableName).Select("version").Pluck("version", &versions).Error; err != nil {
		return nil, fmt.Errorf("failed to get applied migrations: %w", err)
	}
	applied := make(map[string]bool, len(versions))
	for _, v := range versions {
		applied[v] = true
	}
	return applied, nil
}

// recordMigration inserts (isUp) or deletes (!isUp) the bookkeeping row for
// version directly on the Migrator's connection, creating the table first if
// needed.
func (m *Migrator) recordMigration(version, description string, isUp bool) error {
	if err := m.initTable(); err != nil {
		return err
	}
	return m.recordMigrationWithDB(m.db, version, description, isUp)
}

// Up applies, in ascending Version order, every registered migration that is
// not yet recorded as applied. Each migration and its bookkeeping insert run in
// one transaction (databases that implicitly commit DDL may still keep schema
// changes after a failure). Up stops at the first error; migrations applied
// before it stay applied.
func (m *Migrator) Up() error {
	applied, err := m.getAppliedMigrations()
	if err != nil {
		return err
	}

	for _, migration := range m.sortedMigrations() {
		if applied[migration.Version] {
			continue
		}
		if migration.Up == nil {
			return fmt.Errorf("migration %s has no Up function", migration.Version)
		}
		if err := m.runMigrationStep(migration, true); err != nil {
			return err
		}
		fmt.Printf("Applied migration: %s - %s\n", migration.Version, migration.Description)
	}
	return nil
}

// Down rolls back the applied registered migration with the highest Version:
// it runs that migration's Down function and deletes its record in one
// transaction. It returns an error when no registered migration is applied or
// when the one to roll back has no Down function.
func (m *Migrator) Down() error {
	applied, err := m.getAppliedMigrations()
	if err != nil {
		return err
	}

	sorted := m.sortedMigrations()
	// Walk from the newest version back to find the latest applied one.
	for i := len(sorted) - 1; i >= 0; i-- {
		migration := sorted[i]
		if !applied[migration.Version] {
			continue
		}
		if migration.Down == nil {
			return fmt.Errorf("migration %s has no Down function", migration.Version)
		}
		if err := m.runMigrationStep(migration, false); err != nil {
			return err
		}
		fmt.Printf("Rolled back migration: %s - %s\n", migration.Version, migration.Description)
		return nil
	}
	return errors.New("no migrations to rollback")
}

// runMigrationStep runs migration.Up (isUp) or migration.Down (!isUp) followed
// by the matching bookkeeping insert/delete inside a single transaction. The
// transaction is rolled back if either step returns an error or panics, and
// committed otherwise. The caller must have checked that the step is non-nil.
func (m *Migrator) runMigrationStep(migration Migration, isUp bool) error {
	step, stepAction, commitAction := migration.Up, "apply migration", "commit migration"
	if !isUp {
		step, stepAction, commitAction = migration.Down, "rollback migration", "commit rollback"
	}

	tx := m.db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("failed to begin transaction: %w", tx.Error)
	}
	defer rollbackOnPanic(tx)

	if err := step(tx); err != nil {
		tx.Rollback()
		return fmt.Errorf("failed to %s %s: %w", stepAction, migration.Version, err)
	}
	if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, isUp); err != nil {
		tx.Rollback()
		return err
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("failed to %s %s: %w", commitAction, migration.Version, err)
	}
	return nil
}

// rollbackOnPanic is deferred right after a transaction is opened. If the
// surrounding function panics, it rolls the transaction back and re-panics so
// the connection is not left holding an open transaction.
func rollbackOnPanic(tx *gorm.DB) {
	if r := recover(); r != nil {
		tx.Rollback()
		panic(r)
	}
}

// recordMigrationWithDB inserts (isUp) or deletes (!isUp) the bookkeeping row
// for version using db, which may be a transaction.
func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool) error {
	if !isUp {
		err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error
		if err != nil {
			return fmt.Errorf("failed to remove migration record: %w", err)
		}
		return nil
	}

	err := db.Exec(
		fmt.Sprintf("INSERT INTO %s (version, description, applied_at) VALUES (?, ?, ?)", m.tableName),
		version, description, time.Now(),
	).Error
	if err != nil {
		return fmt.Errorf("failed to record migration: %w", err)
	}
	return nil
}

// sortedMigrations returns a copy of the registered migrations ordered by
// ascending Version (string comparison). The registered slice itself is not
// reordered.
func (m *Migrator) sortedMigrations() []Migration {
	sorted := make([]Migration, len(m.migrations))
	copy(sorted, m.migrations)
	sort.SliceStable(sorted, func(i, j int) bool {
		return sorted[i].Version < sorted[j].Version
	})
	return sorted
}

// Status reports, in ascending Version order, every registered migration and
// whether it is recorded as applied.
func (m *Migrator) Status() ([]MigrationStatus, error) {
	applied, err := m.getAppliedMigrations()
	if err != nil {
		return nil, err
	}

	sorted := m.sortedMigrations()
	status := make([]MigrationStatus, 0, len(sorted))
	for _, migration := range sorted {
		status = append(status, MigrationStatus{
			Version:     migration.Version,
			Description: migration.Description,
			Applied:     applied[migration.Version],
		})
	}
	return status, nil
}

// MigrationStatus is one row of the report returned by Status.
type MigrationStatus struct {
	Version     string
	Description string
	Applied     bool
}

// LoadMigrationsFromFiles builds migrations from the files in dir that match
// the glob pattern (for example "*.sql" or "*.up.sql").
//
// File names must look like {version}_{description}{ext} or
// {version}_{description}.up{ext}. The rollback script, if any, is
// {version}_{description}.down{ext} in the same directory. Files whose name
// (minus extension) ends in ".down" are never loaded as migrations themselves,
// even when pattern matches them. Each file's whole content is passed to a
// single Exec call, so multi-statement files may need a driver configured to
// accept them.
//
// The result is sorted by Version. An error is returned for a malformed file
// name, a version used by more than one file, or an up/down file that exists
// but cannot be read. A missing down file is not an error; that migration just
// gets no Down function.
func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
	files, err := filepath.Glob(filepath.Join(dir, pattern))
	if err != nil {
		return nil, fmt.Errorf("failed to glob migration files: %w", err)
	}

	migrations := make([]Migration, 0, len(files))
	seen := make(map[string]string, len(files)) // version -> file name that defined it
	for _, file := range files {
		baseName := filepath.Base(file)
		ext := filepath.Ext(baseName)
		stem := strings.TrimSuffix(baseName, ext)

		// Down scripts are paired with their up script below; when the glob
		// also matches them (e.g. "*.sql") they must not become migrations.
		if strings.HasSuffix(stem, ".down") {
			continue
		}
		stem = strings.TrimSuffix(stem, ".up")

		parts := strings.SplitN(stem, "_", 2)
		if len(parts) < 2 || parts[0] == "" {
			return nil, fmt.Errorf("invalid migration file name format: %s (expected: {version}_{description})", baseName)
		}
		version, description := parts[0], parts[1]
		if prev, dup := seen[version]; dup {
			return nil, fmt.Errorf("duplicate migration version %s: %s and %s", version, prev, baseName)
		}
		seen[version] = baseName

		content, err := os.ReadFile(file)
		if err != nil {
			return nil, fmt.Errorf("failed to read migration file %s: %w", file, err)
		}
		upSQL := string(content)

		// Derive the down script from the parsed stem and the real extension
		// rather than by string replacement on the whole path.
		downFile := filepath.Join(filepath.Dir(file), stem+".down"+ext)
		var downSQL string
		downContent, err := os.ReadFile(downFile)
		switch {
		case err == nil:
			downSQL = string(downContent)
		case !errors.Is(err, fs.ErrNotExist):
			return nil, fmt.Errorf("failed to read migration file %s: %w", downFile, err)
		}

		migration := Migration{
			Version:     version,
			Description: description,
			Up: func(db *gorm.DB) error {
				return db.Exec(upSQL).Error
			},
		}
		if downSQL != "" {
			migration.Down = func(db *gorm.DB) error {
				return db.Exec(downSQL).Error
			}
		}
		migrations = append(migrations, migration)
	}

	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})
	return migrations, nil
}

// GenerateVersion returns the current Unix time in seconds as a decimal string
// for use as a Migration.Version. These values stay 10 digits wide until the
// year 2286, so they order correctly as strings. Two calls within the same
// second return the same value.
func GenerateVersion() string {
	return strconv.FormatInt(time.Now().Unix(), 10)
}