// Package migration provides a small versioned database-migration runner
// built on GORM.
package migration
|
||
|
||
import (
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Migration 表示一个数据库迁移
|
||
type Migration struct {
|
||
Version string
|
||
Description string
|
||
Up func(*gorm.DB) error
|
||
Down func(*gorm.DB) error
|
||
}
|
||
|
||
// Migrator runs registered migrations against a database and keeps track of
// the applied versions in a dedicated table.
type Migrator struct {
	// db is the GORM handle used for bookkeeping queries and for opening the
	// per-migration transactions.
	db *gorm.DB
	// migrations holds every registered migration; it is sorted by Version
	// when migrations are run.
	migrations []Migration
	// tableName is the name of the table that records applied versions.
	tableName string
}
|
||
|
||
// NewMigrator 创建新的迁移器
|
||
// db: GORM数据库连接
|
||
// tableName: 存储迁移记录的表名,默认为 "schema_migrations"
|
||
func NewMigrator(db *gorm.DB, tableName ...string) *Migrator {
|
||
table := "schema_migrations"
|
||
if len(tableName) > 0 && tableName[0] != "" {
|
||
table = tableName[0]
|
||
}
|
||
|
||
return &Migrator{
|
||
db: db,
|
||
migrations: make([]Migration, 0),
|
||
tableName: table,
|
||
}
|
||
}
|
||
|
||
// AddMigration 添加迁移
|
||
func (m *Migrator) AddMigration(migration Migration) {
|
||
m.migrations = append(m.migrations, migration)
|
||
}
|
||
|
||
// AddMigrations 批量添加迁移
|
||
func (m *Migrator) AddMigrations(migrations ...Migration) {
|
||
m.migrations = append(m.migrations, migrations...)
|
||
}
|
||
|
||
// initTable 初始化迁移记录表
|
||
func (m *Migrator) initTable() error {
|
||
// 检查表是否存在
|
||
var exists bool
|
||
err := m.db.Raw(fmt.Sprintf(`
|
||
SELECT EXISTS (
|
||
SELECT FROM information_schema.tables
|
||
WHERE table_schema = CURRENT_SCHEMA()
|
||
AND table_name = '%s'
|
||
)
|
||
`, m.tableName)).Scan(&exists).Error
|
||
|
||
if err != nil {
|
||
// 如果查询失败,可能是SQLite或其他数据库,尝试直接创建
|
||
exists = false
|
||
}
|
||
|
||
if !exists {
|
||
// 创建迁移记录表
|
||
err = m.db.Exec(fmt.Sprintf(`
|
||
CREATE TABLE IF NOT EXISTS %s (
|
||
version VARCHAR(255) PRIMARY KEY,
|
||
description VARCHAR(255),
|
||
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||
)
|
||
`, m.tableName)).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create migration table: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// getAppliedMigrations 获取已应用的迁移版本
|
||
func (m *Migrator) getAppliedMigrations() (map[string]bool, error) {
|
||
if err := m.initTable(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var versions []string
|
||
err := m.db.Table(m.tableName).Select("version").Pluck("version", &versions).Error
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get applied migrations: %w", err)
|
||
}
|
||
|
||
applied := make(map[string]bool)
|
||
for _, v := range versions {
|
||
applied[v] = true
|
||
}
|
||
|
||
return applied, nil
|
||
}
|
||
|
||
// recordMigration 记录迁移
|
||
func (m *Migrator) recordMigration(version, description string, isUp bool) error {
|
||
if err := m.initTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
if isUp {
|
||
// 记录迁移
|
||
err := m.db.Exec(fmt.Sprintf(`
|
||
INSERT INTO %s (version, description, applied_at)
|
||
VALUES (?, ?, ?)
|
||
`, m.tableName), version, description, time.Now()).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to record migration: %w", err)
|
||
}
|
||
} else {
|
||
// 删除迁移记录
|
||
err := m.db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to remove migration record: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Up 执行所有未应用的迁移
|
||
func (m *Migrator) Up() error {
|
||
if err := m.initTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
applied, err := m.getAppliedMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 排序迁移
|
||
sort.Slice(m.migrations, func(i, j int) bool {
|
||
return m.migrations[i].Version < m.migrations[j].Version
|
||
})
|
||
|
||
// 执行未应用的迁移
|
||
for _, migration := range m.migrations {
|
||
if applied[migration.Version] {
|
||
continue
|
||
}
|
||
|
||
if migration.Up == nil {
|
||
return fmt.Errorf("migration %s has no Up function", migration.Version)
|
||
}
|
||
|
||
// 开始事务
|
||
tx := m.db.Begin()
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
|
||
}
|
||
|
||
// 执行迁移
|
||
if err := migration.Up(tx); err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
|
||
}
|
||
|
||
// 记录迁移
|
||
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true); err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err)
|
||
}
|
||
|
||
fmt.Printf("Applied migration: %s - %s\n", migration.Version, migration.Description)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Down 回滚最后一个迁移
|
||
func (m *Migrator) Down() error {
|
||
if err := m.initTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
applied, err := m.getAppliedMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 排序迁移(倒序)
|
||
sort.Slice(m.migrations, func(i, j int) bool {
|
||
return m.migrations[i].Version > m.migrations[j].Version
|
||
})
|
||
|
||
// 找到最后一个已应用的迁移并回滚
|
||
for _, migration := range m.migrations {
|
||
if !applied[migration.Version] {
|
||
continue
|
||
}
|
||
|
||
if migration.Down == nil {
|
||
return fmt.Errorf("migration %s has no Down function", migration.Version)
|
||
}
|
||
|
||
// 开始事务
|
||
tx := m.db.Begin()
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
|
||
}
|
||
|
||
// 执行回滚
|
||
if err := migration.Down(tx); err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err)
|
||
}
|
||
|
||
// 删除迁移记录
|
||
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, false); err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err)
|
||
}
|
||
|
||
fmt.Printf("Rolled back migration: %s - %s\n", migration.Version, migration.Description)
|
||
return nil
|
||
}
|
||
|
||
return fmt.Errorf("no migrations to rollback")
|
||
}
|
||
|
||
// recordMigrationWithDB 使用指定的数据库连接记录迁移
|
||
func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool) error {
|
||
if isUp {
|
||
err := db.Exec(fmt.Sprintf(`
|
||
INSERT INTO %s (version, description, applied_at)
|
||
VALUES (?, ?, ?)
|
||
`, m.tableName), version, description, time.Now()).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to record migration: %w", err)
|
||
}
|
||
} else {
|
||
err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to remove migration record: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Status 查看迁移状态
|
||
func (m *Migrator) Status() ([]MigrationStatus, error) {
|
||
if err := m.initTable(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
applied, err := m.getAppliedMigrations()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 排序迁移
|
||
sort.Slice(m.migrations, func(i, j int) bool {
|
||
return m.migrations[i].Version < m.migrations[j].Version
|
||
})
|
||
|
||
status := make([]MigrationStatus, 0, len(m.migrations))
|
||
for _, migration := range m.migrations {
|
||
status = append(status, MigrationStatus{
|
||
Version: migration.Version,
|
||
Description: migration.Description,
|
||
Applied: applied[migration.Version],
|
||
})
|
||
}
|
||
|
||
return status, nil
|
||
}
|
||
|
||
// MigrationStatus is one row of the report produced by Migrator.Status: the
// version and description of a registered migration, and whether that version
// is recorded as applied.
type MigrationStatus struct {
	Version, Description string
	Applied              bool
}
|
||
|
||
// LoadMigrationsFromFiles 从文件系统加载迁移文件
|
||
// dir: 迁移文件目录
|
||
// pattern: 文件命名模式,例如 "*.sql" 或 "*.up.sql"
|
||
// 文件命名格式: {version}_{description}.sql 或 {version}_{description}.up.sql
|
||
func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
|
||
files, err := filepath.Glob(filepath.Join(dir, pattern))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to glob migration files: %w", err)
|
||
}
|
||
|
||
migrations := make([]Migration, 0)
|
||
for _, file := range files {
|
||
baseName := filepath.Base(file)
|
||
// 移除扩展名
|
||
nameWithoutExt := strings.TrimSuffix(baseName, filepath.Ext(baseName))
|
||
// 移除 .up 后缀(如果存在)
|
||
nameWithoutExt = strings.TrimSuffix(nameWithoutExt, ".up")
|
||
|
||
// 解析版本号和描述
|
||
parts := strings.SplitN(nameWithoutExt, "_", 2)
|
||
if len(parts) < 2 {
|
||
return nil, fmt.Errorf("invalid migration file name format: %s (expected: {version}_{description})", baseName)
|
||
}
|
||
|
||
version := parts[0]
|
||
description := strings.Join(parts[1:], "_")
|
||
|
||
// 读取文件内容
|
||
content, err := os.ReadFile(file)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to read migration file %s: %w", file, err)
|
||
}
|
||
|
||
sqlContent := string(content)
|
||
|
||
// 查找对应的 down 文件
|
||
downFile := strings.Replace(file, ".up.sql", ".down.sql", 1)
|
||
downFile = strings.Replace(downFile, ".sql", ".down.sql", 1)
|
||
var downSQL string
|
||
if downContent, err := os.ReadFile(downFile); err == nil {
|
||
downSQL = string(downContent)
|
||
}
|
||
|
||
migration := Migration{
|
||
Version: version,
|
||
Description: description,
|
||
Up: func(db *gorm.DB) error {
|
||
return db.Exec(sqlContent).Error
|
||
},
|
||
}
|
||
|
||
if downSQL != "" {
|
||
migration.Down = func(db *gorm.DB) error {
|
||
return db.Exec(downSQL).Error
|
||
}
|
||
}
|
||
|
||
migrations = append(migrations, migration)
|
||
}
|
||
|
||
// 按版本号排序
|
||
sort.Slice(migrations, func(i, j int) bool {
|
||
return migrations[i].Version < migrations[j].Version
|
||
})
|
||
|
||
return migrations, nil
|
||
}
|
||
|
||
// GenerateVersion returns a new migration version: the current Unix time in
// seconds, formatted as a base-10 string.
func GenerateVersion() string {
	seconds := time.Now().Unix()
	return strconv.FormatInt(seconds, 10)
}
|