Files
go-common/migration/migration.go

374 lines
9.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package migration
import (
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
"gorm.io/gorm"
)
// Migration describes one versioned database schema change.
//
// Migrations are ordered by comparing Version as a plain string, so versions
// should share a fixed width (for example the Unix-second strings produced by
// GenerateVersion) to sort as intended.
type Migration struct {
	// Version and Description identify the migration; both are written to the
	// bookkeeping table when the migration is applied.
	Version, Description string

	// Up applies the change and Down reverts it. Each receives the transaction
	// the Migrator opened for this migration. Down may be left nil, in which
	// case Migrator.Down reports an error instead of reverting it.
	Up, Down func(*gorm.DB) error
}
// Migrator applies and reverts registered Migrations against a GORM database
// and keeps track of which versions are applied in a bookkeeping table.
type Migrator struct {
	// db is the connection that migrations and bookkeeping queries run on.
	db *gorm.DB

	// migrations holds the migrations registered via AddMigration and
	// AddMigrations.
	migrations []Migration

	// tableName names the bookkeeping table. It is interpolated into SQL
	// statements unquoted, so it must be a trusted, plain identifier.
	tableName string
}
// NewMigrator returns a Migrator that runs migrations on db.
//
// tableName optionally names the table used to record applied migrations.
// Only the first value is consulted; when it is omitted or empty the default
// "schema_migrations" is used. The table is not touched here; it is created
// on demand by the methods that need it.
func NewMigrator(db *gorm.DB, tableName ...string) *Migrator {
	m := &Migrator{
		db:         db,
		migrations: make([]Migration, 0),
		tableName:  "schema_migrations",
	}
	if len(tableName) != 0 && tableName[0] != "" {
		m.tableName = tableName[0]
	}
	return m
}
// AddMigration registers a single migration. It only records the migration on
// the Migrator; nothing is validated or executed until Up, Down or Status runs.
func (m *Migrator) AddMigration(migration Migration) {
	m.AddMigrations(migration)
}
// AddMigrations registers any number of migrations at once, appending them in
// the order given. Duplicate versions are not detected here.
func (m *Migrator) AddMigrations(migrations ...Migration) {
	if len(migrations) == 0 {
		return
	}
	m.migrations = append(m.migrations, migrations...)
}
// initTable makes sure the bookkeeping table exists.
//
// It first probes information_schema for the table in CURRENT_SCHEMA(). Not
// every database supports that probe (SQLite, for instance). If the probe
// fails, the error is deliberately ignored and the method falls back to
// CREATE TABLE IF NOT EXISTS, which is harmless when the table already exists.
//
// In the probe the table name is a value, so it is passed as a bound
// parameter rather than spliced into a quoted literal. In the DDL it is an
// identifier and cannot be bound, so m.tableName must be a trusted name.
func (m *Migrator) initTable() error {
	var exists bool
	err := m.db.Raw(`
		SELECT EXISTS (
			SELECT FROM information_schema.tables
			WHERE table_schema = CURRENT_SCHEMA()
			AND table_name = ?
		)`, m.tableName).Scan(&exists).Error
	if err != nil {
		// Best effort: an unsupported probe means "unknown", so try to create.
		exists = false
	}
	if exists {
		return nil
	}

	err = m.db.Exec(fmt.Sprintf(`
		CREATE TABLE IF NOT EXISTS %s (
			version VARCHAR(255) PRIMARY KEY,
			description VARCHAR(255),
			applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
		)`, m.tableName)).Error
	if err != nil {
		return fmt.Errorf("failed to create migration table: %w", err)
	}
	return nil
}
// getAppliedMigrations returns the set of versions recorded in the bookkeeping
// table, creating the table first if necessary. Every present key maps to
// true; a missing key means the version has not been applied.
func (m *Migrator) getAppliedMigrations() (map[string]bool, error) {
	if err := m.initTable(); err != nil {
		return nil, err
	}

	var versions []string
	query := m.db.Table(m.tableName).Select("version")
	if err := query.Pluck("version", &versions).Error; err != nil {
		return nil, fmt.Errorf("failed to get applied migrations: %w", err)
	}

	applied := make(map[string]bool, len(versions))
	for i := range versions {
		applied[versions[i]] = true
	}
	return applied, nil
}
// recordMigration inserts (isUp == true) or deletes (isUp == false) the
// bookkeeping row for version. It uses the Migrator's own connection m.db
// rather than a caller-supplied transaction, and creates the table first if
// needed.
//
// The SQL lives only in recordMigrationWithDB, so the transactional and
// non-transactional paths cannot drift apart.
func (m *Migrator) recordMigration(version, description string, isUp bool) error {
	if err := m.initTable(); err != nil {
		return err
	}
	return m.recordMigrationWithDB(m.db, version, description, isUp)
}
// Up applies, in ascending Version order (plain string comparison), every
// registered migration whose version is not yet recorded as applied.
//
// Each migration runs in its own transaction together with the insert of its
// bookkeeping row. Up stops at the first failure and returns it; migrations
// committed before that point stay applied. A pending migration with a nil Up
// function is reported as an error.
func (m *Migrator) Up() error {
	// getAppliedMigrations also creates the bookkeeping table if it is missing.
	applied, err := m.getAppliedMigrations()
	if err != nil {
		return err
	}

	// Order a copy so that calling Up does not reorder the Migrator's own slice
	// as a side effect. The stable sort keeps registration order for equal
	// versions.
	ordered := make([]Migration, len(m.migrations))
	copy(ordered, m.migrations)
	sort.SliceStable(ordered, func(i, j int) bool {
		return ordered[i].Version < ordered[j].Version
	})

	for _, migration := range ordered {
		if applied[migration.Version] {
			continue
		}
		if migration.Up == nil {
			return fmt.Errorf("migration %s has no Up function", migration.Version)
		}
		if err := m.applyUp(migration); err != nil {
			return err
		}
		fmt.Printf("Applied migration: %s - %s\n", migration.Version, migration.Description)
	}
	return nil
}

// applyUp runs migration.Up and records it in a single transaction, rolling
// back on error. If Up panics, the transaction is rolled back before the panic
// is re-raised, so it is never left open.
func (m *Migrator) applyUp(migration Migration) error {
	tx := m.db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("failed to begin transaction: %w", tx.Error)
	}
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			panic(r)
		}
	}()

	if err := migration.Up(tx); err != nil {
		tx.Rollback()
		return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
	}
	if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true); err != nil {
		tx.Rollback()
		return err
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err)
	}
	return nil
}
// Down rolls back exactly one migration: the registered migration with the
// greatest Version (plain string comparison) that is recorded as applied.
//
// Its Down function and the deletion of its bookkeeping row run in one
// transaction. Down returns an error if no registered migration is applied,
// or if the selected migration has a nil Down function.
func (m *Migrator) Down() error {
	// getAppliedMigrations also creates the bookkeeping table if it is missing.
	applied, err := m.getAppliedMigrations()
	if err != nil {
		return err
	}

	// Linear scan for the highest applied version. This avoids sorting (and
	// thereby reordering) the Migrator's own slice.
	var target Migration
	found := false
	for _, migration := range m.migrations {
		if !applied[migration.Version] {
			continue
		}
		if !found || migration.Version > target.Version {
			target = migration
			found = true
		}
	}
	if !found {
		return fmt.Errorf("no migrations to rollback")
	}
	if target.Down == nil {
		return fmt.Errorf("migration %s has no Down function", target.Version)
	}

	if err := m.applyDown(target); err != nil {
		return err
	}
	fmt.Printf("Rolled back migration: %s - %s\n", target.Version, target.Description)
	return nil
}

// applyDown runs migration.Down and removes its bookkeeping row in a single
// transaction, rolling back on error. If Down panics, the transaction is
// rolled back before the panic is re-raised, so it is never left open.
func (m *Migrator) applyDown(migration Migration) error {
	tx := m.db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("failed to begin transaction: %w", tx.Error)
	}
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			panic(r)
		}
	}()

	if err := migration.Down(tx); err != nil {
		tx.Rollback()
		return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err)
	}
	if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, false); err != nil {
		tx.Rollback()
		return err
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err)
	}
	return nil
}
// recordMigrationWithDB writes the bookkeeping change for one migration using
// db. Up and Down pass the migration's own transaction here, so the record
// commits or rolls back together with the schema change.
//
// When isUp is true it inserts (version, description, current time).
// Otherwise it deletes the row whose version matches.
func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool) error {
	if !isUp {
		deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName)
		if err := db.Exec(deleteSQL, version).Error; err != nil {
			return fmt.Errorf("failed to remove migration record: %w", err)
		}
		return nil
	}

	insertSQL := fmt.Sprintf("\nINSERT INTO %s (version, description, applied_at)\nVALUES (?, ?, ?)\n", m.tableName)
	if err := db.Exec(insertSQL, version, description, time.Now()).Error; err != nil {
		return fmt.Errorf("failed to record migration: %w", err)
	}
	return nil
}
// Status reports, for every registered migration in ascending Version order
// (plain string comparison), whether its version is recorded as applied.
// Versions present in the bookkeeping table but not registered on this
// Migrator are not included.
func (m *Migrator) Status() ([]MigrationStatus, error) {
	// getAppliedMigrations also creates the bookkeeping table if it is missing.
	applied, err := m.getAppliedMigrations()
	if err != nil {
		return nil, err
	}

	// Build the report first and sort the report itself. A read-only query
	// should not reorder m.migrations in place.
	status := make([]MigrationStatus, 0, len(m.migrations))
	for _, migration := range m.migrations {
		status = append(status, MigrationStatus{
			Version:     migration.Version,
			Description: migration.Description,
			Applied:     applied[migration.Version],
		})
	}
	sort.SliceStable(status, func(i, j int) bool {
		return status[i].Version < status[j].Version
	})
	return status, nil
}
// MigrationStatus is one entry of the report returned by Migrator.Status.
type MigrationStatus struct {
	// Version and Description are copied from the registered Migration.
	Version, Description string

	// Applied reports whether Version is recorded in the bookkeeping table.
	Applied bool
}
// LoadMigrationsFromFiles builds SQL migrations from files in a directory.
//
// dir is the directory to scan. pattern is a filepath.Glob pattern such as
// "*.sql" or "*.up.sql".
//
// Up files must be named {version}_{description}.sql or
// {version}_{description}.up.sql. The text before the first underscore
// becomes Version, and the rest becomes Description.
//
// For each up file, the sibling {version}_{description}.down.<ext>, if
// present, supplies the Down SQL. A missing or empty down file leaves Down
// nil. Files named *.down.<ext> that match the pattern are never loaded as up
// migrations themselves.
//
// The whole file content is passed to a single Exec call. Whether a file may
// contain several statements depends on the database driver (TODO: confirm
// for the driver in use).
//
// The result is sorted by Version (plain string comparison). An error is
// returned for a malformed file name or an unreadable file.
func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
	files, err := filepath.Glob(filepath.Join(dir, pattern))
	if err != nil {
		return nil, fmt.Errorf("failed to glob migration files: %w", err)
	}

	migrations := make([]Migration, 0, len(files))
	for _, file := range files {
		baseName := filepath.Base(file)
		ext := filepath.Ext(baseName)
		stem := strings.TrimSuffix(baseName, ext)

		// A "*.sql" pattern also matches the down files. They are read below as
		// the Down half of their up file and must not be registered as
		// migrations of their own (that would register each version twice).
		if strings.HasSuffix(stem, ".down") {
			continue
		}
		stem = strings.TrimSuffix(stem, ".up")

		// Split "{version}_{description}" on the first underscore only, so the
		// description may itself contain underscores.
		parts := strings.SplitN(stem, "_", 2)
		if len(parts) < 2 || parts[0] == "" {
			return nil, fmt.Errorf("invalid migration file name format: %s (expected: {version}_{description})", baseName)
		}
		version, description := parts[0], parts[1]

		upContent, err := os.ReadFile(file)
		if err != nil {
			return nil, fmt.Errorf("failed to read migration file %s: %w", file, err)
		}
		upSQL := string(upContent)

		// Derive the down path from the stem, so both "x.sql" and "x.up.sql" map
		// to "x.down.sql". Only the base name is rewritten, never a directory
		// component.
		downFile := filepath.Join(filepath.Dir(file), stem+".down"+ext)
		var downSQL string
		downContent, err := os.ReadFile(downFile)
		switch {
		case err == nil:
			downSQL = string(downContent)
		case !os.IsNotExist(err):
			// A missing down file is normal; any other read failure is not.
			return nil, fmt.Errorf("failed to read migration file %s: %w", downFile, err)
		}

		migration := Migration{
			Version:     version,
			Description: description,
			Up: func(db *gorm.DB) error {
				return db.Exec(upSQL).Error
			},
		}
		if downSQL != "" {
			migration.Down = func(db *gorm.DB) error {
				return db.Exec(downSQL).Error
			}
		}
		migrations = append(migrations, migration)
	}

	sort.SliceStable(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})
	return migrations, nil
}
// GenerateVersion returns a new migration version: the current Unix time in
// whole seconds, formatted in base 10. Calls made within the same second
// return the same value, so callers creating several migrations at once must
// disambiguate them.
func GenerateVersion() string {
	seconds := time.Now().Unix()
	return strconv.FormatInt(seconds, 10)
}