修复迁移时,数据库未指定的情况下数据库脚本混乱的问题
This commit is contained in:
@@ -34,14 +34,14 @@ func RunMigrationsFromConfig(configFile, migrationsDir string) error {
|
||||
// RunMigrationsFromConfigWithCommand 从配置文件运行迁移(支持命令,黑盒模式)
|
||||
//
|
||||
// 这是最简单的迁移方式,内部自动处理:
|
||||
// - 配置加载(支持文件、环境变量、默认路径)
|
||||
// - 配置加载(支持配置文件、默认路径)
|
||||
// - 数据库连接(自动识别数据库类型)
|
||||
// - 迁移文件加载和执行
|
||||
//
|
||||
// 参数:
|
||||
// - configFile: 配置文件路径,支持:
|
||||
// - 空字符串:自动查找(config.json, ../config.json)
|
||||
// - 环境变量 DATABASE_URL:直接使用数据库URL
|
||||
// - 相对路径或绝对路径:指定配置文件路径
|
||||
// - migrationsDir: 迁移文件目录,支持:
|
||||
// - 空字符串:使用默认目录 "migrations"
|
||||
// - 相对路径或绝对路径
|
||||
@@ -57,9 +57,6 @@ func RunMigrationsFromConfig(configFile, migrationsDir string) error {
|
||||
//
|
||||
// // 指定配置和迁移目录
|
||||
// migration.RunMigrationsFromConfigWithCommand("config.json", "scripts/sql", "up")
|
||||
//
|
||||
// // 使用环境变量
|
||||
// // DATABASE_URL="mysql://..." migration.RunMigrationsFromConfigWithCommand("", "migrations", "up")
|
||||
func RunMigrationsFromConfigWithCommand(configFile, migrationsDir, command string) error {
|
||||
// 加载配置
|
||||
cfg, err := loadConfigFromFileOrEnv(configFile)
|
||||
@@ -78,8 +75,8 @@ func RunMigrationsFromConfigWithCommand(configFile, migrationsDir, command strin
|
||||
migrationsDir = "migrations"
|
||||
}
|
||||
|
||||
// 创建迁移器
|
||||
migrator := NewMigrator(db)
|
||||
// 创建迁移器(传入数据库类型,性能更好)
|
||||
migrator := NewMigratorWithType(db, cfg.Database.Type)
|
||||
|
||||
// 加载迁移文件
|
||||
migrations, err := LoadMigrationsFromFiles(migrationsDir, "*.sql")
|
||||
@@ -122,22 +119,16 @@ func RunMigrationsFromConfigWithCommand(configFile, migrationsDir, command strin
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadConfigFromFileOrEnv 从文件或环境变量加载配置
|
||||
// loadConfigFromFileOrEnv 从配置文件加载配置
|
||||
// 支持指定配置文件路径,或自动查找默认路径
|
||||
func loadConfigFromFileOrEnv(configFile string) (*config.Config, error) {
|
||||
// 优先从环境变量加载
|
||||
if dbURL := os.Getenv("DATABASE_URL"); dbURL != "" {
|
||||
return &config.Config{
|
||||
Database: &config.DatabaseConfig{
|
||||
DSN: dbURL,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 尝试从配置文件加载
|
||||
// 如果指定了配置文件路径,优先使用
|
||||
if configFile != "" {
|
||||
if _, err := os.Stat(configFile); err == nil {
|
||||
return config.LoadFromFile(configFile)
|
||||
}
|
||||
// 如果指定的文件不存在,返回错误
|
||||
return nil, fmt.Errorf("配置文件不存在: %s", configFile)
|
||||
}
|
||||
|
||||
// 尝试默认路径
|
||||
@@ -148,7 +139,7 @@ func loadConfigFromFileOrEnv(configFile string) (*config.Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("未找到配置文件,也未设置环境变量 DATABASE_URL")
|
||||
return nil, fmt.Errorf("未找到配置文件,请指定配置文件路径或确保存在以下文件之一: %v", defaultPaths)
|
||||
}
|
||||
|
||||
// connectDB 连接数据库
|
||||
|
||||
@@ -26,6 +26,7 @@ type Migrator struct {
|
||||
db *gorm.DB
|
||||
migrations []Migration
|
||||
tableName string
|
||||
dbType string // 数据库类型: mysql, postgres, sqlite
|
||||
}
|
||||
|
||||
// NewMigrator 创建新的迁移器
|
||||
@@ -41,6 +42,25 @@ func NewMigrator(db *gorm.DB, tableName ...string) *Migrator {
|
||||
db: db,
|
||||
migrations: make([]Migration, 0),
|
||||
tableName: table,
|
||||
dbType: "", // 未指定时为空,会使用兼容模式
|
||||
}
|
||||
}
|
||||
|
||||
// NewMigratorWithType 创建新的迁移器(指定数据库类型,性能更好)
|
||||
// db: GORM数据库连接
|
||||
// dbType: 数据库类型 ("mysql", "postgres", "sqlite")
|
||||
// tableName: 存储迁移记录的表名,默认为 "schema_migrations"
|
||||
func NewMigratorWithType(db *gorm.DB, dbType string, tableName ...string) *Migrator {
|
||||
table := "schema_migrations"
|
||||
if len(tableName) > 0 && tableName[0] != "" {
|
||||
table = tableName[0]
|
||||
}
|
||||
|
||||
return &Migrator{
|
||||
db: db,
|
||||
migrations: make([]Migration, 0),
|
||||
tableName: table,
|
||||
dbType: dbType,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,18 +76,78 @@ func (m *Migrator) AddMigrations(migrations ...Migration) {
|
||||
|
||||
// initTable 初始化迁移记录表
|
||||
func (m *Migrator) initTable() error {
|
||||
// 检查表是否存在
|
||||
// 检查表是否存在(根据数据库类型使用对应的SQL,性能更好)
|
||||
var exists bool
|
||||
err := m.db.Raw(fmt.Sprintf(`
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
)
|
||||
`, m.tableName)).Scan(&exists).Error
|
||||
var err error
|
||||
|
||||
switch m.dbType {
|
||||
case "mysql":
|
||||
// MySQL/MariaDB语法
|
||||
var count int64
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*) FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '%s'
|
||||
`, m.tableName)).Scan(&count).Error
|
||||
if err == nil {
|
||||
exists = count > 0
|
||||
}
|
||||
case "postgres":
|
||||
// PostgreSQL语法
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
)
|
||||
`, m.tableName)).Scan(&exists).Error
|
||||
case "sqlite":
|
||||
// SQLite语法
|
||||
var count int64
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*) FROM sqlite_master
|
||||
WHERE type='table' AND name='%s'
|
||||
`, m.tableName)).Scan(&count).Error
|
||||
if err == nil {
|
||||
exists = count > 0
|
||||
}
|
||||
default:
|
||||
// 未指定数据库类型时,使用兼容模式(向后兼容)
|
||||
// 按顺序尝试不同数据库的语法
|
||||
var count int64
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*) FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '%s'
|
||||
`, m.tableName)).Scan(&count).Error
|
||||
if err == nil && count > 0 {
|
||||
exists = true
|
||||
} else {
|
||||
var pgExists bool
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
)
|
||||
`, m.tableName)).Scan(&pgExists).Error
|
||||
if err == nil {
|
||||
exists = pgExists
|
||||
} else {
|
||||
var sqliteCount int64
|
||||
err = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*) FROM sqlite_master
|
||||
WHERE type='table' AND name='%s'
|
||||
`, m.tableName)).Scan(&sqliteCount).Error
|
||||
if err == nil && sqliteCount > 0 {
|
||||
exists = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果查询失败,假设表不存在,尝试创建
|
||||
if err != nil {
|
||||
// 如果查询失败,可能是SQLite或其他数据库,尝试直接创建
|
||||
exists = false
|
||||
}
|
||||
|
||||
@@ -89,19 +169,57 @@ func (m *Migrator) initTable() error {
|
||||
// 注意:这个检查可能在某些数据库中失败,但不影响功能
|
||||
// 如果字段不存在,记录执行时间时会失败,但不影响迁移执行
|
||||
var hasExecutionTime bool
|
||||
checkSQL := fmt.Sprintf(`
|
||||
SELECT COUNT(*) > 0
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
AND column_name = 'execution_time'
|
||||
`, m.tableName)
|
||||
err = m.db.Raw(checkSQL).Scan(&hasExecutionTime).Error
|
||||
if err == nil && !hasExecutionTime {
|
||||
var columnCount int64
|
||||
var checkErr error
|
||||
|
||||
switch m.dbType {
|
||||
case "mysql":
|
||||
// MySQL/MariaDB语法
|
||||
checkErr = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '%s'
|
||||
AND column_name = 'execution_time'
|
||||
`, m.tableName)).Scan(&columnCount).Error
|
||||
if checkErr == nil {
|
||||
hasExecutionTime = columnCount > 0
|
||||
}
|
||||
case "postgres":
|
||||
// PostgreSQL语法
|
||||
checkErr = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = CURRENT_SCHEMA()
|
||||
AND table_name = '%s'
|
||||
AND column_name = 'execution_time'
|
||||
`, m.tableName)).Scan(&columnCount).Error
|
||||
if checkErr == nil {
|
||||
hasExecutionTime = columnCount > 0
|
||||
}
|
||||
case "sqlite":
|
||||
// SQLite不支持information_schema,跳过检查
|
||||
hasExecutionTime = false
|
||||
default:
|
||||
// 兼容模式:尝试MySQL语法
|
||||
checkErr = m.db.Raw(fmt.Sprintf(`
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '%s'
|
||||
AND column_name = 'execution_time'
|
||||
`, m.tableName)).Scan(&columnCount).Error
|
||||
if checkErr == nil {
|
||||
hasExecutionTime = columnCount > 0
|
||||
}
|
||||
}
|
||||
|
||||
if !hasExecutionTime {
|
||||
// 尝试添加字段(如果失败不影响功能)
|
||||
// 注意:SQLite的ALTER TABLE ADD COLUMN语法略有不同,但GORM会处理
|
||||
_ = m.db.Exec(fmt.Sprintf(`
|
||||
ALTER TABLE %s
|
||||
ADD COLUMN execution_time INT COMMENT '执行耗时(ms)'
|
||||
ADD COLUMN execution_time INT
|
||||
`, m.tableName))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user