修复迁移时,数据库未指定的情况下数据库脚本混乱的问题

This commit is contained in:
2025-12-06 22:03:57 +08:00
parent 6547e7bca8
commit b66f345281
7 changed files with 194 additions and 77 deletions

View File

@@ -26,6 +26,7 @@ type Migrator struct {
db *gorm.DB
migrations []Migration
tableName string
dbType string // 数据库类型: mysql, postgres, sqlite
}
// NewMigrator 创建新的迁移器
@@ -41,6 +42,25 @@ func NewMigrator(db *gorm.DB, tableName ...string) *Migrator {
db: db,
migrations: make([]Migration, 0),
tableName: table,
dbType: "", // 未指定时为空,会使用兼容模式
}
}
// NewMigratorWithType 创建新的迁移器(指定数据库类型,性能更好)
// db: GORM数据库连接
// dbType: 数据库类型 ("mysql", "postgres", "sqlite")
// tableName: 存储迁移记录的表名,默认为 "schema_migrations"
func NewMigratorWithType(db *gorm.DB, dbType string, tableName ...string) *Migrator {
table := "schema_migrations"
if len(tableName) > 0 && tableName[0] != "" {
table = tableName[0]
}
return &Migrator{
db: db,
migrations: make([]Migration, 0),
tableName: table,
dbType: dbType,
}
}
@@ -56,18 +76,78 @@ func (m *Migrator) AddMigrations(migrations ...Migration) {
// initTable 初始化迁移记录表
func (m *Migrator) initTable() error {
// 检查表是否存在
// 检查表是否存在根据数据库类型使用对应的SQL性能更好
var exists bool
err := m.db.Raw(fmt.Sprintf(`
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_schema = CURRENT_SCHEMA()
AND table_name = '%s'
)
`, m.tableName)).Scan(&exists).Error
var err error
switch m.dbType {
case "mysql":
// MySQL/MariaDB语法
var count int64
err = m.db.Raw(fmt.Sprintf(`
SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = '%s'
`, m.tableName)).Scan(&count).Error
if err == nil {
exists = count > 0
}
case "postgres":
// PostgreSQL语法
err = m.db.Raw(fmt.Sprintf(`
SELECT EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_schema = CURRENT_SCHEMA()
AND table_name = '%s'
)
`, m.tableName)).Scan(&exists).Error
case "sqlite":
// SQLite语法
var count int64
err = m.db.Raw(fmt.Sprintf(`
SELECT COUNT(*) FROM sqlite_master
WHERE type='table' AND name='%s'
`, m.tableName)).Scan(&count).Error
if err == nil {
exists = count > 0
}
default:
// 未指定数据库类型时,使用兼容模式(向后兼容)
// 按顺序尝试不同数据库的语法
var count int64
err = m.db.Raw(fmt.Sprintf(`
SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = '%s'
`, m.tableName)).Scan(&count).Error
if err == nil && count > 0 {
exists = true
} else {
var pgExists bool
err = m.db.Raw(fmt.Sprintf(`
SELECT EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_schema = CURRENT_SCHEMA()
AND table_name = '%s'
)
`, m.tableName)).Scan(&pgExists).Error
if err == nil {
exists = pgExists
} else {
var sqliteCount int64
err = m.db.Raw(fmt.Sprintf(`
SELECT COUNT(*) FROM sqlite_master
WHERE type='table' AND name='%s'
`, m.tableName)).Scan(&sqliteCount).Error
if err == nil && sqliteCount > 0 {
exists = true
}
}
}
}
// 如果查询失败,假设表不存在,尝试创建
if err != nil {
// 如果查询失败可能是SQLite或其他数据库尝试直接创建
exists = false
}
@@ -89,19 +169,57 @@ func (m *Migrator) initTable() error {
// 注意:这个检查可能在某些数据库中失败,但不影响功能
// 如果字段不存在,记录执行时间时会失败,但不影响迁移执行
var hasExecutionTime bool
checkSQL := fmt.Sprintf(`
SELECT COUNT(*) > 0
FROM information_schema.columns
WHERE table_schema = CURRENT_SCHEMA()
AND table_name = '%s'
AND column_name = 'execution_time'
`, m.tableName)
err = m.db.Raw(checkSQL).Scan(&hasExecutionTime).Error
if err == nil && !hasExecutionTime {
var columnCount int64
var checkErr error
switch m.dbType {
case "mysql":
// MySQL/MariaDB语法
checkErr = m.db.Raw(fmt.Sprintf(`
SELECT COUNT(*)
FROM information_schema.columns
WHERE table_schema = DATABASE()
AND table_name = '%s'
AND column_name = 'execution_time'
`, m.tableName)).Scan(&columnCount).Error
if checkErr == nil {
hasExecutionTime = columnCount > 0
}
case "postgres":
// PostgreSQL语法
checkErr = m.db.Raw(fmt.Sprintf(`
SELECT COUNT(*)
FROM information_schema.columns
WHERE table_schema = CURRENT_SCHEMA()
AND table_name = '%s'
AND column_name = 'execution_time'
`, m.tableName)).Scan(&columnCount).Error
if checkErr == nil {
hasExecutionTime = columnCount > 0
}
case "sqlite":
// SQLite不支持information_schema跳过检查
hasExecutionTime = false
default:
// 兼容模式尝试MySQL语法
checkErr = m.db.Raw(fmt.Sprintf(`
SELECT COUNT(*)
FROM information_schema.columns
WHERE table_schema = DATABASE()
AND table_name = '%s'
AND column_name = 'execution_time'
`, m.tableName)).Scan(&columnCount).Error
if checkErr == nil {
hasExecutionTime = columnCount > 0
}
}
if !hasExecutionTime {
// 尝试添加字段(如果失败不影响功能)
// 注意SQLite的ALTER TABLE ADD COLUMN语法略有不同但GORM会处理
_ = m.db.Exec(fmt.Sprintf(`
ALTER TABLE %s
ADD COLUMN execution_time INT COMMENT '执行耗时(ms)'
ADD COLUMN execution_time INT
`, m.tableName))
}
}