修复迁移时,数据库未指定的情况下数据库脚本混乱的问题
This commit is contained in:
@@ -26,6 +26,7 @@ type Migrator struct {
|
||||
db *gorm.DB
|
||||
migrations []Migration
|
||||
tableName string
|
||||
dbType string // 数据库类型: mysql, postgres, sqlite
|
||||
}
|
||||
|
||||
// NewMigrator 创建新的迁移器
|
||||
@@ -41,6 +42,25 @@ func NewMigrator(db *gorm.DB, tableName ...string) *Migrator {
|
||||
db: db,
|
||||
migrations: make([]Migration, 0),
|
||||
tableName: table,
|
||||
dbType: "", // 未指定时为空,会使用兼容模式
|
||||
}
|
||||
}
|
||||
|
||||
// NewMigratorWithType 创建新的迁移器(指定数据库类型,性能更好)
|
||||
// db: GORM数据库连接
|
||||
// dbType: 数据库类型 ("mysql", "postgres", "sqlite")
|
||||
// tableName: 存储迁移记录的表名,默认为 "schema_migrations"
|
||||
func NewMigratorWithType(db *gorm.DB, dbType string, tableName ...string) *Migrator {
|
||||
table := "schema_migrations"
|
||||
if len(tableName) > 0 && tableName[0] != "" {
|
||||
table = tableName[0]
|
||||
}
|
||||
|
||||
return &Migrator{
|
||||
db: db,
|
||||
migrations: make([]Migration, 0),
|
||||
tableName: table,
|
||||
dbType: dbType,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,18 +76,78 @@ func (m *Migrator) AddMigrations(migrations ...Migration) {
|
||||
|
||||
// initTable 初始化迁移记录表
|
||||
func (m *Migrator) initTable() error {
|
||||
// 检查表是否存在
|
||||
// 检查表是否存在(根据数据库类型使用对应的SQL,性能更好)
|
||||
var exists bool
|
||||
err := m.db.Raw(fmt.Sprintf(`
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
)
|
||||
`, m.tableName)).Scan(&exists).Error
|
||||
var err error
|
||||
|
||||
switch m.dbType {
|
||||
case "mysql":
|
||||
// MySQL/MariaDB语法
|
||||
var count int64
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*) FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '%s'
|
||||
`, m.tableName)).Scan(&count).Error
|
||||
if err == nil {
|
||||
exists = count > 0
|
||||
}
|
||||
case "postgres":
|
||||
// PostgreSQL语法
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
)
|
||||
`, m.tableName)).Scan(&exists).Error
|
||||
case "sqlite":
|
||||
// SQLite语法
|
||||
var count int64
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*) FROM sqlite_master
|
||||
WHERE type='table' AND name='%s'
|
||||
`, m.tableName)).Scan(&count).Error
|
||||
if err == nil {
|
||||
exists = count > 0
|
||||
}
|
||||
default:
|
||||
// 未指定数据库类型时,使用兼容模式(向后兼容)
|
||||
// 按顺序尝试不同数据库的语法
|
||||
var count int64
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*) FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '%s'
|
||||
`, m.tableName)).Scan(&count).Error
|
||||
if err == nil && count > 0 {
|
||||
exists = true
|
||||
} else {
|
||||
var pgExists bool
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
)
|
||||
`, m.tableName)).Scan(&pgExists).Error
|
||||
if err == nil {
|
||||
exists = pgExists
|
||||
} else {
|
||||
var sqliteCount int64
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*) FROM sqlite_master
|
||||
WHERE type='table' AND name='%s'
|
||||
`, m.tableName)).Scan(&sqliteCount).Error
|
||||
if err == nil && sqliteCount > 0 {
|
||||
exists = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果查询失败,假设表不存在,尝试创建
|
||||
if err != nil {
|
||||
// 如果查询失败,可能是SQLite或其他数据库,尝试直接创建
|
||||
exists = false
|
||||
}
|
||||
|
||||
@@ -89,19 +169,57 @@ func (m *Migrator) initTable() error {
|
||||
// 注意:这个检查可能在某些数据库中失败,但不影响功能
|
||||
// 如果字段不存在,记录执行时间时会失败,但不影响迁移执行
|
||||
var hasExecutionTime bool
|
||||
checkSQL := fmt.Sprintf(`
|
||||
SELECT COUNT(*) > 0
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
AND column_name = 'execution_time'
|
||||
`, m.tableName)
|
||||
err = m.db.Raw(checkSQL).Scan(&hasExecutionTime).Error
|
||||
if err == nil && !hasExecutionTime {
|
||||
var columnCount int64
|
||||
var checkErr error
|
||||
|
||||
switch m.dbType {
|
||||
case "mysql":
|
||||
// MySQL/MariaDB语法
|
||||
checkErr = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '%s'
|
||||
AND column_name = 'execution_time'
|
||||
`, m.tableName)).Scan(&columnCount).Error
|
||||
if checkErr == nil {
|
||||
hasExecutionTime = columnCount > 0
|
||||
}
|
||||
case "postgres":
|
||||
// PostgreSQL语法
|
||||
checkErr = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
AND column_name = 'execution_time'
|
||||
`, m.tableName)).Scan(&columnCount).Error
|
||||
if checkErr == nil {
|
||||
hasExecutionTime = columnCount > 0
|
||||
}
|
||||
case "sqlite":
|
||||
// SQLite不支持information_schema,跳过检查
|
||||
hasExecutionTime = false
|
||||
default:
|
||||
// 兼容模式:尝试MySQL语法
|
||||
checkErr = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '%s'
|
||||
AND column_name = 'execution_time'
|
||||
`, m.tableName)).Scan(&columnCount).Error
|
||||
if checkErr == nil {
|
||||
hasExecutionTime = columnCount > 0
|
||||
}
|
||||
}
|
||||
|
||||
if !hasExecutionTime {
|
||||
// 尝试添加字段(如果失败不影响功能)
|
||||
// 注意:SQLite的ALTER TABLE ADD COLUMN语法略有不同,但GORM会处理
|
||||
_ = m.db.Exec(fmt.Sprintf(`
|
||||
ALTER TABLE %s
|
||||
ADD COLUMN execution_time INT COMMENT '执行耗时(ms)'
|
||||
ADD COLUMN execution_time INT
|
||||
`, m.tableName))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user