909 lines
25 KiB
Go
909 lines
25 KiB
Go
package migration
|
||
|
||
import (
|
||
"bufio"
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Migration 表示一个数据库迁移
|
||
type Migration struct {
|
||
Version string
|
||
Description string
|
||
Up func(*gorm.DB) error
|
||
Down func(*gorm.DB) error
|
||
}
|
||
|
||
// Migrator 数据库迁移器
|
||
type Migrator struct {
|
||
db *gorm.DB
|
||
migrations []Migration
|
||
tableName string
|
||
dbType string // 数据库类型: mysql, postgres, sqlite
|
||
}
|
||
|
||
// NewMigrator 创建新的迁移器
|
||
// db: GORM数据库连接
|
||
// tableName: 存储迁移记录的表名,默认为 "schema_migrations"
|
||
func NewMigrator(db *gorm.DB, tableName ...string) *Migrator {
|
||
table := "schema_migrations"
|
||
if len(tableName) > 0 && tableName[0] != "" {
|
||
table = tableName[0]
|
||
}
|
||
|
||
return &Migrator{
|
||
db: db,
|
||
migrations: make([]Migration, 0),
|
||
tableName: table,
|
||
dbType: "", // 未指定时为空,会使用兼容模式
|
||
}
|
||
}
|
||
|
||
// NewMigratorWithType 创建新的迁移器(指定数据库类型,性能更好)
|
||
// db: GORM数据库连接
|
||
// dbType: 数据库类型 ("mysql", "postgres", "sqlite")
|
||
// tableName: 存储迁移记录的表名,默认为 "schema_migrations"
|
||
func NewMigratorWithType(db *gorm.DB, dbType string, tableName ...string) *Migrator {
|
||
table := "schema_migrations"
|
||
if len(tableName) > 0 && tableName[0] != "" {
|
||
table = tableName[0]
|
||
}
|
||
|
||
return &Migrator{
|
||
db: db,
|
||
migrations: make([]Migration, 0),
|
||
tableName: table,
|
||
dbType: dbType,
|
||
}
|
||
}
|
||
|
||
// AddMigration 添加迁移
|
||
func (m *Migrator) AddMigration(migration Migration) {
|
||
m.migrations = append(m.migrations, migration)
|
||
}
|
||
|
||
// AddMigrations 批量添加迁移
|
||
func (m *Migrator) AddMigrations(migrations ...Migration) {
|
||
m.migrations = append(m.migrations, migrations...)
|
||
}
|
||
|
||
// initTable 初始化迁移记录表
|
||
func (m *Migrator) initTable() error {
|
||
// 检查表是否存在(根据数据库类型使用对应的SQL,性能更好)
|
||
var exists bool
|
||
var err error
|
||
|
||
switch m.dbType {
|
||
case "mysql":
|
||
// MySQL/MariaDB语法
|
||
var count int64
|
||
err = m.db.Raw(fmt.Sprintf(`
|
||
SELECT COUNT(*) FROM information_schema.tables
|
||
WHERE table_schema = DATABASE()
|
||
AND table_name = '%s'
|
||
`, m.tableName)).Scan(&count).Error
|
||
if err == nil {
|
||
exists = count > 0
|
||
}
|
||
case "postgres":
|
||
// PostgreSQL语法
|
||
err = m.db.Raw(fmt.Sprintf(`
|
||
SELECT EXISTS (
|
||
SELECT 1 FROM information_schema.tables
|
||
WHERE table_schema = CURRENT_SCHEMA()
|
||
AND table_name = '%s'
|
||
)
|
||
`, m.tableName)).Scan(&exists).Error
|
||
case "sqlite":
|
||
// SQLite语法
|
||
var count int64
|
||
err = m.db.Raw(fmt.Sprintf(`
|
||
SELECT COUNT(*) FROM sqlite_master
|
||
WHERE type='table' AND name='%s'
|
||
`, m.tableName)).Scan(&count).Error
|
||
if err == nil {
|
||
exists = count > 0
|
||
}
|
||
default:
|
||
// 未指定数据库类型时,使用兼容模式(向后兼容)
|
||
// 按顺序尝试不同数据库的语法
|
||
var count int64
|
||
err = m.db.Raw(fmt.Sprintf(`
|
||
SELECT COUNT(*) FROM information_schema.tables
|
||
WHERE table_schema = DATABASE()
|
||
AND table_name = '%s'
|
||
`, m.tableName)).Scan(&count).Error
|
||
if err == nil && count > 0 {
|
||
exists = true
|
||
} else {
|
||
var pgExists bool
|
||
err = m.db.Raw(fmt.Sprintf(`
|
||
SELECT EXISTS (
|
||
SELECT 1 FROM information_schema.tables
|
||
WHERE table_schema = CURRENT_SCHEMA()
|
||
AND table_name = '%s'
|
||
)
|
||
`, m.tableName)).Scan(&pgExists).Error
|
||
if err == nil {
|
||
exists = pgExists
|
||
} else {
|
||
var sqliteCount int64
|
||
err = m.db.Raw(fmt.Sprintf(`
|
||
SELECT COUNT(*) FROM sqlite_master
|
||
WHERE type='table' AND name='%s'
|
||
`, m.tableName)).Scan(&sqliteCount).Error
|
||
if err == nil && sqliteCount > 0 {
|
||
exists = true
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果查询失败,假设表不存在,尝试创建
|
||
if err != nil {
|
||
exists = false
|
||
}
|
||
|
||
if !exists {
|
||
// 创建迁移记录表(包含执行时间字段)
|
||
err = m.db.Exec(fmt.Sprintf(`
|
||
CREATE TABLE IF NOT EXISTS %s (
|
||
version VARCHAR(255) PRIMARY KEY,
|
||
description VARCHAR(255),
|
||
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||
execution_time INT COMMENT '执行耗时(ms)'
|
||
)
|
||
`, m.tableName)).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create migration table: %w", err)
|
||
}
|
||
} else {
|
||
// 表已存在,检查是否有 execution_time 字段(向后兼容)
|
||
// 注意:这个检查可能在某些数据库中失败,但不影响功能
|
||
// 如果字段不存在,记录执行时间时会失败,但不影响迁移执行
|
||
var hasExecutionTime bool
|
||
var columnCount int64
|
||
var checkErr error
|
||
|
||
switch m.dbType {
|
||
case "mysql":
|
||
// MySQL/MariaDB语法
|
||
checkErr = m.db.Raw(fmt.Sprintf(`
|
||
SELECT COUNT(*)
|
||
FROM information_schema.columns
|
||
WHERE table_schema = DATABASE()
|
||
AND table_name = '%s'
|
||
AND column_name = 'execution_time'
|
||
`, m.tableName)).Scan(&columnCount).Error
|
||
if checkErr == nil {
|
||
hasExecutionTime = columnCount > 0
|
||
}
|
||
case "postgres":
|
||
// PostgreSQL语法
|
||
checkErr = m.db.Raw(fmt.Sprintf(`
|
||
SELECT COUNT(*)
|
||
FROM information_schema.columns
|
||
WHERE table_schema = CURRENT_SCHEMA()
|
||
AND table_name = '%s'
|
||
AND column_name = 'execution_time'
|
||
`, m.tableName)).Scan(&columnCount).Error
|
||
if checkErr == nil {
|
||
hasExecutionTime = columnCount > 0
|
||
}
|
||
case "sqlite":
|
||
// SQLite不支持information_schema,跳过检查
|
||
hasExecutionTime = false
|
||
default:
|
||
// 兼容模式:尝试MySQL语法
|
||
checkErr = m.db.Raw(fmt.Sprintf(`
|
||
SELECT COUNT(*)
|
||
FROM information_schema.columns
|
||
WHERE table_schema = DATABASE()
|
||
AND table_name = '%s'
|
||
AND column_name = 'execution_time'
|
||
`, m.tableName)).Scan(&columnCount).Error
|
||
if checkErr == nil {
|
||
hasExecutionTime = columnCount > 0
|
||
}
|
||
}
|
||
|
||
if !hasExecutionTime {
|
||
// 尝试添加字段(如果失败不影响功能)
|
||
// 注意:SQLite的ALTER TABLE ADD COLUMN语法略有不同,但GORM会处理
|
||
_ = m.db.Exec(fmt.Sprintf(`
|
||
ALTER TABLE %s
|
||
ADD COLUMN execution_time INT
|
||
`, m.tableName))
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// getAppliedMigrations 获取已应用的迁移版本
|
||
func (m *Migrator) getAppliedMigrations() (map[string]bool, error) {
|
||
if err := m.initTable(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var versions []string
|
||
err := m.db.Table(m.tableName).Select("version").Pluck("version", &versions).Error
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get applied migrations: %w", err)
|
||
}
|
||
|
||
applied := make(map[string]bool)
|
||
for _, v := range versions {
|
||
applied[v] = true
|
||
}
|
||
|
||
return applied, nil
|
||
}
|
||
|
||
// recordMigration 记录迁移
|
||
func (m *Migrator) recordMigration(version, description string, isUp bool, executionTime ...int) error {
|
||
if err := m.initTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
if isUp {
|
||
// 记录迁移(包含执行时间,如果提供了)
|
||
var err error
|
||
if len(executionTime) > 0 && executionTime[0] > 0 {
|
||
// 尝试插入执行时间(如果字段存在)
|
||
err = m.db.Exec(fmt.Sprintf(`
|
||
INSERT INTO %s (version, description, applied_at, execution_time)
|
||
VALUES (?, ?, ?, ?)
|
||
`, m.tableName), version, description, time.Now(), executionTime[0]).Error
|
||
if err != nil {
|
||
// 如果失败(可能是字段不存在),尝试不包含执行时间
|
||
err = m.db.Exec(fmt.Sprintf(`
|
||
INSERT INTO %s (version, description, applied_at)
|
||
VALUES (?, ?, ?)
|
||
`, m.tableName), version, description, time.Now()).Error
|
||
}
|
||
} else {
|
||
// 不包含执行时间
|
||
err = m.db.Exec(fmt.Sprintf(`
|
||
INSERT INTO %s (version, description, applied_at)
|
||
VALUES (?, ?, ?)
|
||
`, m.tableName), version, description, time.Now()).Error
|
||
}
|
||
if err != nil {
|
||
return fmt.Errorf("failed to record migration: %w", err)
|
||
}
|
||
} else {
|
||
// 删除迁移记录
|
||
err := m.db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to remove migration record: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Up 执行所有未应用的迁移
|
||
func (m *Migrator) Up() error {
|
||
if err := m.initTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
applied, err := m.getAppliedMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 排序迁移
|
||
sort.Slice(m.migrations, func(i, j int) bool {
|
||
return m.migrations[i].Version < m.migrations[j].Version
|
||
})
|
||
|
||
// 执行未应用的迁移
|
||
for _, migration := range m.migrations {
|
||
if applied[migration.Version] {
|
||
continue
|
||
}
|
||
|
||
if migration.Up == nil {
|
||
return fmt.Errorf("migration %s has no Up function", migration.Version)
|
||
}
|
||
|
||
// 记录开始时间
|
||
startTime := time.Now()
|
||
|
||
// 开始事务
|
||
tx := m.db.Begin()
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
|
||
}
|
||
|
||
// 执行迁移
|
||
if err := migration.Up(tx); err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
|
||
}
|
||
|
||
// 计算执行时间(毫秒)
|
||
executionTime := int(time.Since(startTime).Milliseconds())
|
||
|
||
// 记录迁移(包含执行时间)
|
||
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true, executionTime); err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err)
|
||
}
|
||
|
||
fmt.Printf("Applied migration: %s - %s (耗时: %dms)\n", migration.Version, migration.Description, executionTime)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Down 回滚最后一个迁移
|
||
func (m *Migrator) Down() error {
|
||
if err := m.initTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
applied, err := m.getAppliedMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 排序迁移(倒序)
|
||
sort.Slice(m.migrations, func(i, j int) bool {
|
||
return m.migrations[i].Version > m.migrations[j].Version
|
||
})
|
||
|
||
// 找到最后一个已应用的迁移并回滚
|
||
for _, migration := range m.migrations {
|
||
if !applied[migration.Version] {
|
||
continue
|
||
}
|
||
|
||
if migration.Down == nil {
|
||
return fmt.Errorf("migration %s has no Down function", migration.Version)
|
||
}
|
||
|
||
// 开始事务
|
||
tx := m.db.Begin()
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
|
||
}
|
||
|
||
// 执行回滚
|
||
if err := migration.Down(tx); err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err)
|
||
}
|
||
|
||
// 删除迁移记录
|
||
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, false); err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err)
|
||
}
|
||
|
||
fmt.Printf("Rolled back migration: %s - %s\n", migration.Version, migration.Description)
|
||
return nil
|
||
}
|
||
|
||
return fmt.Errorf("no migrations to rollback")
|
||
}
|
||
|
||
// recordMigrationWithDB 使用指定的数据库连接记录迁移
|
||
func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool, executionTime ...int) error {
|
||
if isUp {
|
||
var err error
|
||
if len(executionTime) > 0 && executionTime[0] > 0 {
|
||
// 尝试插入执行时间(如果字段存在)
|
||
err = db.Exec(fmt.Sprintf(`
|
||
INSERT INTO %s (version, description, applied_at, execution_time)
|
||
VALUES (?, ?, ?, ?)
|
||
`, m.tableName), version, description, time.Now(), executionTime[0]).Error
|
||
if err != nil {
|
||
// 如果失败(可能是字段不存在),尝试不包含执行时间
|
||
err = db.Exec(fmt.Sprintf(`
|
||
INSERT INTO %s (version, description, applied_at)
|
||
VALUES (?, ?, ?)
|
||
`, m.tableName), version, description, time.Now()).Error
|
||
}
|
||
} else {
|
||
// 不包含执行时间
|
||
err = db.Exec(fmt.Sprintf(`
|
||
INSERT INTO %s (version, description, applied_at)
|
||
VALUES (?, ?, ?)
|
||
`, m.tableName), version, description, time.Now()).Error
|
||
}
|
||
if err != nil {
|
||
return fmt.Errorf("failed to record migration: %w", err)
|
||
}
|
||
} else {
|
||
err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to remove migration record: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Status 查看迁移状态
|
||
func (m *Migrator) Status() ([]MigrationStatus, error) {
|
||
if err := m.initTable(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
applied, err := m.getAppliedMigrations()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 排序迁移
|
||
sort.Slice(m.migrations, func(i, j int) bool {
|
||
return m.migrations[i].Version < m.migrations[j].Version
|
||
})
|
||
|
||
status := make([]MigrationStatus, 0, len(m.migrations))
|
||
for _, migration := range m.migrations {
|
||
status = append(status, MigrationStatus{
|
||
Version: migration.Version,
|
||
Description: migration.Description,
|
||
Applied: applied[migration.Version],
|
||
})
|
||
}
|
||
|
||
return status, nil
|
||
}
|
||
|
||
// MigrationStatus describes one registered migration as reported by
// Migrator.Status: its version, its description, and whether that version is
// recorded in the migration table.
type MigrationStatus struct {
	Version, Description string
	Applied              bool
}
// splitSQL splits a SQL script into individual statements.
//
// The text is scanned byte by byte. All delimiters are ASCII, so multi-byte
// UTF-8 text passes through untouched. The scanner:
//   - ends a statement at each ';' outside quotes and comments, keeping the ';';
//   - removes "--" line comments and "/* ... */" block comments wherever they
//     appear on a line (a removed block comment becomes a single space);
//   - keeps MySQL executable comments ("/*! ... */") verbatim, since the server
//     runs their contents;
//   - never splits inside '...' strings (escaped by '' or a backslash),
//     "..." or `...` quoted identifiers, or PostgreSQL dollar-quoted bodies
//     ($$ ... $$, $tag$ ... $tag$).
//
// Returned statements are trimmed of surrounding whitespace, and empty ones
// are dropped. Text after the last ';' is returned as a final statement.
//
// Limitations: '#' comments are not recognised, "--" starts a comment even
// without trailing whitespace, and a ';' inside a BEGIN ... END body that is
// not dollar-quoted (e.g. MySQL or SQLite triggers) still ends the statement.
func splitSQL(content string) []string {
	var (
		statements []string
		current    strings.Builder
	)
	n := len(content)

	// flush appends the pending statement if it contains any non-space text.
	flush := func() {
		if stmt := strings.TrimSpace(current.String()); stmt != "" {
			statements = append(statements, stmt)
		}
		current.Reset()
	}

	// isIdentByte reports whether b can be part of an unquoted identifier;
	// bytes >= 0x80 belong to multi-byte UTF-8 letters.
	isIdentByte := func(b byte) bool {
		return b == '_' || b >= 0x80 ||
			(b >= '0' && b <= '9') || (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z')
	}

	// quotedEnd returns the index just past the quoted text whose opening
	// quote is at content[start], or n if it is never closed.
	quotedEnd := func(start int) int {
		quote := content[start]
		for j := start + 1; j < n; j++ {
			switch content[j] {
			case '\\':
				if quote == '\'' {
					j++ // backslash escape (MySQL style): skip the escaped byte
				}
			case quote:
				if j+1 < n && content[j+1] == quote {
					j++ // a doubled quote is an escaped quote
					continue
				}
				return j + 1
			}
		}
		return n
	}

	// dollarQuotedEnd reports whether a dollar-quoted string starts at start
	// and, if so, returns the index just past its closing tag (or n).
	dollarQuotedEnd := func(start int) (int, bool) {
		if start > 0 && isIdentByte(content[start-1]) {
			return 0, false // '$' inside an identifier such as a$b
		}
		j := start + 1
		for j < n && isIdentByte(content[j]) {
			j++
		}
		if j >= n || content[j] != '$' {
			return 0, false
		}
		if j > start+1 && content[start+1] >= '0' && content[start+1] <= '9' {
			return 0, false // tags cannot start with a digit ($1 is a parameter)
		}
		tag := content[start : j+1]
		if end := strings.Index(content[j+1:], tag); end >= 0 {
			return j + 1 + end + len(tag), true
		}
		return n, true
	}

	for i := 0; i < n; {
		c := content[i]
		next := i + 1 // by default consume one byte

		switch {
		case c == '-' && next < n && content[next] == '-':
			// Line comment: drop up to, but not including, the newline.
			if nl := strings.IndexByte(content[i:], '\n'); nl >= 0 {
				next = i + nl
			} else {
				next = n
			}
		case c == '/' && next < n && content[next] == '*':
			if end := strings.Index(content[i+2:], "*/"); end >= 0 {
				next = i + 2 + end + 2
			} else {
				next = n
			}
			if i+2 < n && content[i+2] == '!' {
				current.WriteString(content[i:next])
			} else {
				current.WriteByte(' ') // a comment separates tokens like whitespace
			}
		case c == '\'' || c == '"' || c == '`':
			next = quotedEnd(i)
			current.WriteString(content[i:next])
		case c == '$':
			if end, ok := dollarQuotedEnd(i); ok {
				next = end
			}
			current.WriteString(content[i:next])
		case c == ';':
			current.WriteByte(';')
			flush()
		default:
			current.WriteByte(c)
		}
		i = next
	}
	flush()

	return statements
}
// parseMigrationFileName splits a migration file name into version and
// description.
//
// The extension and then an optional ".up" suffix are stripped, and the rest
// is split at the first underscore:
//
//	01_init_schema.sql                  -> "01", "init_schema"
//	20240101000001_create_users.sql     -> "20240101000001", "create_users"
//	20240101000001_create_users.up.sql  -> "20240101000001", "create_users"
//
// A name without an underscore is accepted for backward compatibility: the
// stripped name becomes the version and the full baseName the description.
// The error result is currently always nil.
func parseMigrationFileName(baseName string) (string, string, error) {
	stem := strings.TrimSuffix(baseName, filepath.Ext(baseName))
	stem = strings.TrimSuffix(stem, ".up")

	sep := strings.Index(stem, "_")
	if sep < 0 {
		return stem, baseName, nil
	}
	return stem[:sep], stem[sep+1:], nil
}
// LoadMigrationsFromFiles 从文件系统加载迁移文件
|
||
// dir: 迁移文件目录
|
||
// pattern: 文件命名模式,例如 "*.sql" 或 "*.up.sql"
|
||
// 文件命名格式支持:
|
||
// - 数字前缀: 01_init_schema.sql
|
||
// - 时间戳: 20240101000001_create_users.sql
|
||
// - 带.up后缀: 20240101000001_create_users.up.sql
|
||
func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
|
||
files, err := filepath.Glob(filepath.Join(dir, pattern))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to glob migration files: %w", err)
|
||
}
|
||
|
||
migrations := make([]Migration, 0)
|
||
for _, file := range files {
|
||
baseName := filepath.Base(file)
|
||
|
||
// 跳过 .down.sql 文件(会在处理 .up.sql 或 .sql 时自动加载)
|
||
if strings.HasSuffix(baseName, ".down.sql") {
|
||
continue
|
||
}
|
||
|
||
// 解析版本号和描述
|
||
version, description, err := parseMigrationFileName(baseName)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("invalid migration file name format: %s: %w", baseName, err)
|
||
}
|
||
|
||
// 读取文件内容
|
||
content, err := os.ReadFile(file)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to read migration file %s: %w", file, err)
|
||
}
|
||
|
||
sqlContent := string(content)
|
||
|
||
// 查找对应的 down 文件
|
||
downFile := strings.Replace(file, ".up.sql", ".down.sql", 1)
|
||
downFile = strings.Replace(downFile, ".sql", ".down.sql", 1)
|
||
var downSQL string
|
||
if downContent, err := os.ReadFile(downFile); err == nil {
|
||
downSQL = string(downContent)
|
||
}
|
||
|
||
// 创建迁移,使用 SQL 分割功能
|
||
migration := Migration{
|
||
Version: version,
|
||
Description: description,
|
||
Up: func(db *gorm.DB) error {
|
||
// 分割 SQL 语句
|
||
statements := splitSQL(sqlContent)
|
||
if len(statements) == 0 {
|
||
return nil // 空文件,跳过
|
||
}
|
||
|
||
// 执行每个 SQL 语句
|
||
// 注意:某些 DDL 语句(如 CREATE TABLE)在某些数据库中会隐式提交事务
|
||
// 因此这里不使用事务,而是逐个执行
|
||
for i, stmt := range statements {
|
||
stmt = strings.TrimSpace(stmt)
|
||
if stmt == "" {
|
||
continue
|
||
}
|
||
|
||
if err := db.Exec(stmt).Error; err != nil {
|
||
// 如果是表已存在的错误,记录警告但继续执行(向后兼容)
|
||
errStr := err.Error()
|
||
if strings.Contains(errStr, "already exists") ||
|
||
strings.Contains(errStr, "Duplicate") ||
|
||
strings.Contains(errStr, "duplicate") {
|
||
fmt.Printf("Warning: SQL statement %d in migration %s: %v\n", i+1, version, err)
|
||
continue
|
||
}
|
||
return fmt.Errorf("failed to execute SQL statement %d in migration %s: %w\nSQL: %s", i+1, version, err, stmt)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
},
|
||
}
|
||
|
||
if downSQL != "" {
|
||
migration.Down = func(db *gorm.DB) error {
|
||
// 分割 SQL 语句
|
||
statements := splitSQL(downSQL)
|
||
if len(statements) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 执行每个 SQL 语句
|
||
for i, stmt := range statements {
|
||
stmt = strings.TrimSpace(stmt)
|
||
if stmt == "" {
|
||
continue
|
||
}
|
||
|
||
if err := db.Exec(stmt).Error; err != nil {
|
||
return fmt.Errorf("failed to execute SQL statement %d in rollback %s: %w\nSQL: %s", i+1, version, err, stmt)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
}
|
||
|
||
migrations = append(migrations, migration)
|
||
}
|
||
|
||
// 按版本号排序(支持数字和时间戳混合排序)
|
||
sort.Slice(migrations, func(i, j int) bool {
|
||
vi, vj := migrations[i].Version, migrations[j].Version
|
||
// 尝试按数字排序(如果是数字前缀)
|
||
if viNum, err1 := strconv.Atoi(vi); err1 == nil {
|
||
if vjNum, err2 := strconv.Atoi(vj); err2 == nil {
|
||
return viNum < vjNum
|
||
}
|
||
}
|
||
// 否则按字符串排序
|
||
return vi < vj
|
||
})
|
||
|
||
return migrations, nil
|
||
}
|
||
|
||
// GenerateVersion returns a migration version derived from the current time:
// the Unix timestamp in seconds, formatted as a decimal string. Two calls
// within the same second return the same value.
func GenerateVersion() string {
	now := time.Now()
	seconds := now.Unix()
	return strconv.FormatInt(seconds, 10)
}
// Reset 重置所有迁移(清空迁移记录表)
|
||
// confirm: 确认标志,必须为true才能执行重置
|
||
// 注意:此操作会清空所有迁移记录,但不会回滚已执行的迁移操作
|
||
// 如果需要回滚迁移,请先使用Down方法逐个回滚
|
||
func (m *Migrator) Reset(confirm bool) error {
|
||
if !confirm {
|
||
return fmt.Errorf("reset operation requires explicit confirmation (confirm must be true)")
|
||
}
|
||
|
||
if err := m.initTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 获取已应用的迁移数量
|
||
applied, err := m.getAppliedMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
count := len(applied)
|
||
if count == 0 {
|
||
fmt.Println("No migrations to reset")
|
||
return nil
|
||
}
|
||
|
||
// 清空迁移记录表
|
||
err = m.db.Exec(fmt.Sprintf("DELETE FROM %s", m.tableName)).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to reset migrations: %w", err)
|
||
}
|
||
|
||
fmt.Printf("Reset completed: %d migration record(s) cleared\n", count)
|
||
fmt.Println("WARNING: This only clears migration records, not the actual database changes.")
|
||
fmt.Println("If you need to rollback database changes, use Down() method before reset.")
|
||
|
||
return nil
|
||
}
|
||
|
||
// ResetWithConfirm 交互式重置所有迁移(带确认提示)
|
||
// 会提示用户数据会被清空,需要输入确认才能执行
|
||
func (m *Migrator) ResetWithConfirm() error {
|
||
if err := m.initTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 获取已应用的迁移数量
|
||
applied, err := m.getAppliedMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
count := len(applied)
|
||
if count == 0 {
|
||
fmt.Println("No migrations to reset")
|
||
return nil
|
||
}
|
||
|
||
// 显示警告信息
|
||
fmt.Println("=" + strings.Repeat("=", 60) + "=")
|
||
fmt.Println("WARNING: This operation will clear all migration records!")
|
||
fmt.Println("=" + strings.Repeat("=", 60) + "=")
|
||
fmt.Printf("This will clear %d migration record(s) from the database.\n", count)
|
||
fmt.Println()
|
||
fmt.Println("IMPORTANT NOTES:")
|
||
fmt.Println(" 1. This operation only clears migration records, NOT the actual database changes.")
|
||
fmt.Println(" 2. If you need to rollback database changes, use Down() method before reset.")
|
||
fmt.Println(" 3. After reset, you may need to re-run migrations if the database structure changed.")
|
||
fmt.Println()
|
||
fmt.Print("Type 'RESET' (all caps) to confirm, or anything else to cancel: ")
|
||
|
||
// 读取用户输入
|
||
reader := bufio.NewReader(os.Stdin)
|
||
input, err := reader.ReadString('\n')
|
||
if err != nil {
|
||
return fmt.Errorf("failed to read confirmation: %w", err)
|
||
}
|
||
|
||
// 去除换行符和空格
|
||
input = strings.TrimSpace(input)
|
||
|
||
// 检查确认
|
||
if input != "RESET" {
|
||
fmt.Println("Reset operation cancelled.")
|
||
return nil
|
||
}
|
||
|
||
// 执行重置
|
||
return m.Reset(true)
|
||
}
|
||
|
||
// ResetAll 重置所有迁移并回滚所有已应用的迁移
|
||
// confirm: 确认标志,必须为true才能执行
|
||
// 注意:此操作会回滚所有已应用的迁移,然后清空迁移记录
|
||
func (m *Migrator) ResetAll(confirm bool) error {
|
||
if !confirm {
|
||
return fmt.Errorf("reset all operation requires explicit confirmation (confirm must be true)")
|
||
}
|
||
|
||
if err := m.initTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 获取已应用的迁移
|
||
applied, err := m.getAppliedMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
count := len(applied)
|
||
if count == 0 {
|
||
fmt.Println("No migrations to reset")
|
||
return nil
|
||
}
|
||
|
||
// 回滚所有已应用的迁移
|
||
fmt.Printf("Rolling back %d migration(s)...\n", count)
|
||
|
||
// 排序迁移(倒序)
|
||
sort.Slice(m.migrations, func(i, j int) bool {
|
||
return m.migrations[i].Version > m.migrations[j].Version
|
||
})
|
||
|
||
// 回滚所有已应用的迁移
|
||
for _, migration := range m.migrations {
|
||
if !applied[migration.Version] {
|
||
continue
|
||
}
|
||
|
||
if migration.Down == nil {
|
||
fmt.Printf("Warning: Migration %s has no Down function, skipping rollback\n", migration.Version)
|
||
continue
|
||
}
|
||
|
||
// 开始事务
|
||
tx := m.db.Begin()
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
|
||
}
|
||
|
||
// 执行回滚
|
||
if err := migration.Down(tx); err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err)
|
||
}
|
||
|
||
// 删除迁移记录
|
||
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, false); err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err)
|
||
}
|
||
|
||
fmt.Printf("Rolled back migration: %s - %s\n", migration.Version, migration.Description)
|
||
}
|
||
|
||
fmt.Printf("Reset all completed: %d migration(s) rolled back and records cleared\n", count)
|
||
return nil
|
||
}
|
||
|
||
// ResetAllWithConfirm 交互式重置所有迁移并回滚(带确认提示)
|
||
// 会提示用户数据会被清空,需要输入确认才能执行
|
||
func (m *Migrator) ResetAllWithConfirm() error {
|
||
if err := m.initTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 获取已应用的迁移数量
|
||
applied, err := m.getAppliedMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
count := len(applied)
|
||
if count == 0 {
|
||
fmt.Println("No migrations to reset")
|
||
return nil
|
||
}
|
||
|
||
// 显示警告信息
|
||
fmt.Println("=" + strings.Repeat("=", 60) + "=")
|
||
fmt.Println("WARNING: This operation will rollback ALL migrations and clear records!")
|
||
fmt.Println("=" + strings.Repeat("=", 60) + "=")
|
||
fmt.Printf("This will rollback %d migration(s) and clear all migration records.\n", count)
|
||
fmt.Println()
|
||
fmt.Println("IMPORTANT NOTES:")
|
||
fmt.Println(" 1. This operation will EXECUTE Down() functions for all applied migrations.")
|
||
fmt.Println(" 2. All database changes made by migrations will be REVERTED.")
|
||
fmt.Println(" 3. This operation CANNOT be undone!")
|
||
fmt.Println(" 4. Make sure you have a database backup before proceeding.")
|
||
fmt.Println()
|
||
fmt.Print("Type 'RESET ALL' (all caps) to confirm, or anything else to cancel: ")
|
||
|
||
// 读取用户输入
|
||
reader := bufio.NewReader(os.Stdin)
|
||
input, err := reader.ReadString('\n')
|
||
if err != nil {
|
||
return fmt.Errorf("failed to read confirmation: %w", err)
|
||
}
|
||
|
||
// 去除换行符和空格
|
||
input = strings.TrimSpace(input)
|
||
|
||
// 检查确认
|
||
if input != "RESET ALL" {
|
||
fmt.Println("Reset all operation cancelled.")
|
||
return nil
|
||
}
|
||
|
||
// 执行重置
|
||
return m.ResetAll(true)
|
||
}
|