初始版本,工具基础类
This commit is contained in:
373
migration/migration.go
Normal file
373
migration/migration.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Migration 表示一个数据库迁移
|
||||
type Migration struct {
|
||||
Version string
|
||||
Description string
|
||||
Up func(*gorm.DB) error
|
||||
Down func(*gorm.DB) error
|
||||
}
|
||||
|
||||
// Migrator 数据库迁移器
|
||||
type Migrator struct {
|
||||
db *gorm.DB
|
||||
migrations []Migration
|
||||
tableName string
|
||||
}
|
||||
|
||||
// NewMigrator 创建新的迁移器
|
||||
// db: GORM数据库连接
|
||||
// tableName: 存储迁移记录的表名,默认为 "schema_migrations"
|
||||
func NewMigrator(db *gorm.DB, tableName ...string) *Migrator {
|
||||
table := "schema_migrations"
|
||||
if len(tableName) > 0 && tableName[0] != "" {
|
||||
table = tableName[0]
|
||||
}
|
||||
|
||||
return &Migrator{
|
||||
db: db,
|
||||
migrations: make([]Migration, 0),
|
||||
tableName: table,
|
||||
}
|
||||
}
|
||||
|
||||
// AddMigration 添加迁移
|
||||
func (m *Migrator) AddMigration(migration Migration) {
|
||||
m.migrations = append(m.migrations, migration)
|
||||
}
|
||||
|
||||
// AddMigrations 批量添加迁移
|
||||
func (m *Migrator) AddMigrations(migrations ...Migration) {
|
||||
m.migrations = append(m.migrations, migrations...)
|
||||
}
|
||||
|
||||
// initTable 初始化迁移记录表
|
||||
func (m *Migrator) initTable() error {
|
||||
// 检查表是否存在
|
||||
var exists bool
|
||||
err := m.db.Raw(fmt.Sprintf(`
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
)
|
||||
`, m.tableName)).Scan(&exists).Error
|
||||
|
||||
if err != nil {
|
||||
// 如果查询失败,可能是SQLite或其他数据库,尝试直接创建
|
||||
exists = false
|
||||
}
|
||||
|
||||
if !exists {
|
||||
// 创建迁移记录表
|
||||
err = m.db.Exec(fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
version VARCHAR(255) PRIMARY KEY,
|
||||
description VARCHAR(255),
|
||||
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`, m.tableName)).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migration table: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAppliedMigrations 获取已应用的迁移版本
|
||||
func (m *Migrator) getAppliedMigrations() (map[string]bool, error) {
|
||||
if err := m.initTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var versions []string
|
||||
err := m.db.Table(m.tableName).Select("version").Pluck("version", &versions).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get applied migrations: %w", err)
|
||||
}
|
||||
|
||||
applied := make(map[string]bool)
|
||||
for _, v := range versions {
|
||||
applied[v] = true
|
||||
}
|
||||
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
// recordMigration 记录迁移
|
||||
func (m *Migrator) recordMigration(version, description string, isUp bool) error {
|
||||
if err := m.initTable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isUp {
|
||||
// 记录迁移
|
||||
err := m.db.Exec(fmt.Sprintf(`
|
||||
INSERT INTO %s (version, description, applied_at)
|
||||
VALUES (?, ?, ?)
|
||||
`, m.tableName), version, description, time.Now()).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record migration: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 删除迁移记录
|
||||
err := m.db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove migration record: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Up 执行所有未应用的迁移
|
||||
func (m *Migrator) Up() error {
|
||||
if err := m.initTable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applied, err := m.getAppliedMigrations()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 排序迁移
|
||||
sort.Slice(m.migrations, func(i, j int) bool {
|
||||
return m.migrations[i].Version < m.migrations[j].Version
|
||||
})
|
||||
|
||||
// 执行未应用的迁移
|
||||
for _, migration := range m.migrations {
|
||||
if applied[migration.Version] {
|
||||
continue
|
||||
}
|
||||
|
||||
if migration.Up == nil {
|
||||
return fmt.Errorf("migration %s has no Up function", migration.Version)
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
tx := m.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
|
||||
}
|
||||
|
||||
// 执行迁移
|
||||
if err := migration.Up(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
// 记录迁移
|
||||
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
fmt.Printf("Applied migration: %s - %s\n", migration.Version, migration.Description)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Down 回滚最后一个迁移
|
||||
func (m *Migrator) Down() error {
|
||||
if err := m.initTable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applied, err := m.getAppliedMigrations()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 排序迁移(倒序)
|
||||
sort.Slice(m.migrations, func(i, j int) bool {
|
||||
return m.migrations[i].Version > m.migrations[j].Version
|
||||
})
|
||||
|
||||
// 找到最后一个已应用的迁移并回滚
|
||||
for _, migration := range m.migrations {
|
||||
if !applied[migration.Version] {
|
||||
continue
|
||||
}
|
||||
|
||||
if migration.Down == nil {
|
||||
return fmt.Errorf("migration %s has no Down function", migration.Version)
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
tx := m.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
|
||||
}
|
||||
|
||||
// 执行回滚
|
||||
if err := migration.Down(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
// 删除迁移记录
|
||||
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, false); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
fmt.Printf("Rolled back migration: %s - %s\n", migration.Version, migration.Description)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("no migrations to rollback")
|
||||
}
|
||||
|
||||
// recordMigrationWithDB 使用指定的数据库连接记录迁移
|
||||
func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool) error {
|
||||
if isUp {
|
||||
err := db.Exec(fmt.Sprintf(`
|
||||
INSERT INTO %s (version, description, applied_at)
|
||||
VALUES (?, ?, ?)
|
||||
`, m.tableName), version, description, time.Now()).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record migration: %w", err)
|
||||
}
|
||||
} else {
|
||||
err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove migration record: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Status 查看迁移状态
|
||||
func (m *Migrator) Status() ([]MigrationStatus, error) {
|
||||
if err := m.initTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
applied, err := m.getAppliedMigrations()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 排序迁移
|
||||
sort.Slice(m.migrations, func(i, j int) bool {
|
||||
return m.migrations[i].Version < m.migrations[j].Version
|
||||
})
|
||||
|
||||
status := make([]MigrationStatus, 0, len(m.migrations))
|
||||
for _, migration := range m.migrations {
|
||||
status = append(status, MigrationStatus{
|
||||
Version: migration.Version,
|
||||
Description: migration.Description,
|
||||
Applied: applied[migration.Version],
|
||||
})
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// MigrationStatus 迁移状态
|
||||
type MigrationStatus struct {
|
||||
Version string
|
||||
Description string
|
||||
Applied bool
|
||||
}
|
||||
|
||||
// LoadMigrationsFromFiles 从文件系统加载迁移文件
|
||||
// dir: 迁移文件目录
|
||||
// pattern: 文件命名模式,例如 "*.sql" 或 "*.up.sql"
|
||||
// 文件命名格式: {version}_{description}.sql 或 {version}_{description}.up.sql
|
||||
func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
|
||||
files, err := filepath.Glob(filepath.Join(dir, pattern))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to glob migration files: %w", err)
|
||||
}
|
||||
|
||||
migrations := make([]Migration, 0)
|
||||
for _, file := range files {
|
||||
baseName := filepath.Base(file)
|
||||
// 移除扩展名
|
||||
nameWithoutExt := strings.TrimSuffix(baseName, filepath.Ext(baseName))
|
||||
// 移除 .up 后缀(如果存在)
|
||||
nameWithoutExt = strings.TrimSuffix(nameWithoutExt, ".up")
|
||||
|
||||
// 解析版本号和描述
|
||||
parts := strings.SplitN(nameWithoutExt, "_", 2)
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("invalid migration file name format: %s (expected: {version}_{description})", baseName)
|
||||
}
|
||||
|
||||
version := parts[0]
|
||||
description := strings.Join(parts[1:], "_")
|
||||
|
||||
// 读取文件内容
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read migration file %s: %w", file, err)
|
||||
}
|
||||
|
||||
sqlContent := string(content)
|
||||
|
||||
// 查找对应的 down 文件
|
||||
downFile := strings.Replace(file, ".up.sql", ".down.sql", 1)
|
||||
downFile = strings.Replace(downFile, ".sql", ".down.sql", 1)
|
||||
var downSQL string
|
||||
if downContent, err := os.ReadFile(downFile); err == nil {
|
||||
downSQL = string(downContent)
|
||||
}
|
||||
|
||||
migration := Migration{
|
||||
Version: version,
|
||||
Description: description,
|
||||
Up: func(db *gorm.DB) error {
|
||||
return db.Exec(sqlContent).Error
|
||||
},
|
||||
}
|
||||
|
||||
if downSQL != "" {
|
||||
migration.Down = func(db *gorm.DB) error {
|
||||
return db.Exec(downSQL).Error
|
||||
}
|
||||
}
|
||||
|
||||
migrations = append(migrations, migration)
|
||||
}
|
||||
|
||||
// 按版本号排序
|
||||
sort.Slice(migrations, func(i, j int) bool {
|
||||
return migrations[i].Version < migrations[j].Version
|
||||
})
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
// GenerateVersion 生成迁移版本号(基于时间戳)
|
||||
func GenerateVersion() string {
|
||||
return strconv.FormatInt(time.Now().Unix(), 10)
|
||||
}
|
||||
Reference in New Issue
Block a user