diff --git a/docs/migration.md b/docs/migration.md index d67f389..dec09c8 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -13,7 +13,101 @@ - 自动创建迁移记录表 - 事务支持,确保迁移的原子性 -## 使用方法 +## 🚀 最简单的使用方式(黑盒模式,推荐) + +这是最简单的迁移方式,内部自动处理配置加载、数据库连接、迁移执行等所有细节。 + +### 方式一:使用独立迁移工具(推荐) + +1. **复制迁移工具模板到你的项目**: + +```bash +mkdir -p cmd/migrate +cp /path/to/go-common/templates/migrate/main.go cmd/migrate/ +``` + +2. **创建迁移文件**: + +```bash +mkdir -p migrations +# 或使用其他目录,如 scripts/sql +``` + +创建 `migrations/01_init_schema.sql`: + +```sql +CREATE TABLE users ( + id BIGINT PRIMARY KEY AUTO_INCREMENT, + username VARCHAR(255) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); +``` + +3. **编译和使用**: + +```bash +# 编译 +go build -o bin/migrate cmd/migrate/main.go + +# 使用(最简单,使用默认配置和默认迁移目录) +./bin/migrate up + +# 指定配置文件 +./bin/migrate up -config config.json + +# 指定配置和迁移目录 +./bin/migrate up -config config.json -dir scripts/sql + +# 查看状态 +./bin/migrate status + +# 回滚 +./bin/migrate down +``` + +**特点**: +- ✅ 零配置:使用默认值即可运行 +- ✅ 自动查找配置:支持环境变量、默认路径 +- ✅ 自动处理:配置加载、数据库连接、迁移执行全自动 + +### 方式二:在代码中直接调用(简单场景) + +```go +import "git.toowon.com/jimmy/go-common/migration" + +// 最简单:使用默认配置和默认迁移目录 +err := migration.RunMigrationsFromConfigWithCommand("", "", "up") + +// 指定配置文件,使用默认迁移目录 +err := migration.RunMigrationsFromConfigWithCommand("config.json", "", "up") + +// 指定配置和迁移目录 +err := migration.RunMigrationsFromConfigWithCommand("config.json", "scripts/sql", "up") + +// 查看状态 +err := migration.RunMigrationsFromConfigWithCommand("config.json", "migrations", "status") +``` + +**参数说明**: +- `configFile`: 配置文件路径,空字符串时自动查找(config.json, ../config.json)或使用环境变量 DATABASE_URL +- `migrationsDir`: 迁移文件目录,空字符串时使用默认值 "migrations" +- `command`: 命令,支持 "up", "down", "status" + +### 方式三:使用Factory(如果项目已使用Factory) + +```go +import "git.toowon.com/jimmy/go-common/factory" + +fac, _ := factory.NewFactoryFromFile("config.json") +// 使用默认目录 "migrations" +err := fac.RunMigrations() +// 或指定目录 +err := fac.RunMigrations("scripts/sql") +``` + +--- + +## 详细使用方法(高级功能) ### 1. 创建迁移器 @@ -86,14 +180,16 @@ migrations := []migration.Migration{ migrator.AddMigrations(migrations...) ``` -#### 方式三:从文件加载迁移 +#### 方式三:从文件加载迁移(推荐) ```go -// 文件命名格式: {version}_{description}.sql 或 {version}_{description}.up.sql -// 例如: 20240101000001_create_users_table.up.sql -// 对应的回滚文件: 20240101000001_create_users_table.down.sql +// 支持的文件命名格式: +// 1. 数字前缀: 01_init_schema.sql +// 2. 时间戳: 20240101000001_create_users.sql +// 3. 带.up后缀: 20240101000001_create_users.up.sql +// 对应的回滚文件: 20240101000001_create_users.down.sql -migrations, err := migration.LoadMigrationsFromFiles("./migrations", "*.up.sql") +migrations, err := migration.LoadMigrationsFromFiles("./migrations", "*.sql") if err != nil { log.Fatal(err) } @@ -101,6 +197,12 @@ if err != nil { migrator.AddMigrations(migrations...) ``` +**新特性:** +- ✅ 支持数字前缀命名(如 `01_init_schema.sql`) +- ✅ 自动分割多行 SQL 语句 +- ✅ 自动处理注释(单行 `--` 和多行 `/* */`) +- ✅ 记录执行时间(毫秒) + ### 3. 执行迁移 ```go @@ -289,15 +391,31 @@ type MigrationStatus struct { **参数:** - `dir`: 迁移文件目录 -- `pattern`: 文件匹配模式,如 "*.up.sql" +- `pattern`: 文件匹配模式,如 "*.sql" 或 "*.up.sql" **返回:** 迁移列表和错误信息 -**文件命名格式:** `{version}_{description}.up.sql` +**支持的文件命名格式:** -**示例:** -- `20240101000001_create_users_table.up.sql` - 升级文件 -- `20240101000001_create_users_table.down.sql` - 回滚文件(可选) +1. **数字前缀格式**(新支持): + - `01_init_schema.sql` + - `02_init_data.sql` + - `03_add_log_schema.sql` + +2. **时间戳格式**(现有): + - `20240101000001_create_users.sql` + - `20240101000002_add_index.sql` + +3. **带.up后缀格式**(现有): + - `20240101000001_create_users.up.sql` - 升级文件 + - `20240101000001_create_users.down.sql` - 回滚文件(可选) + +**新特性:** +- ✅ 自动识别文件命名格式(数字前缀或时间戳) +- ✅ 自动分割多行 SQL 语句(按分号分割) +- ✅ 自动处理注释(单行 `--` 和多行 `/* */`) +- ✅ 自动跳过空行和空白字符 +- ✅ 支持一个文件包含多个 SQL 语句 #### GenerateVersion() string diff --git a/factory/factory.go b/factory/factory.go index 65595fd..2a9074e 100644 --- a/factory/factory.go +++ b/factory/factory.go @@ -11,6 +11,7 @@ import ( "git.toowon.com/jimmy/go-common/email" "git.toowon.com/jimmy/go-common/logger" "git.toowon.com/jimmy/go-common/middleware" + "git.toowon.com/jimmy/go-common/migration" "git.toowon.com/jimmy/go-common/sms" "git.toowon.com/jimmy/go-common/storage" "github.com/redis/go-redis/v9" @@ -764,3 +765,92 @@ func (f *Factory) GetMiddlewareChain() *middleware.Chain { return middleware.NewChain(middlewares...) } + +// RunMigrations 执行数据库迁移(黑盒模式,推荐使用) +// 自动发现并执行指定目录下的所有迁移文件 +// migrationsDir: 迁移文件目录(如 "migrations" 或 "scripts/sql") +// +// 支持的文件命名格式: +// - 数字前缀: 01_init_schema.sql +// - 时间戳: 20240101000001_create_users.sql +// - 带.up后缀: 20240101000001_create_users.up.sql +// +// 示例: +// +// fac, _ := factory.NewFactoryFromFile("config.json") +// err := fac.RunMigrations("migrations") +// if err != nil { +// log.Fatal(err) +// } +func (f *Factory) RunMigrations(migrationsDir string) error { + // 获取数据库连接 + db, err := f.getDatabase() + if err != nil { + return fmt.Errorf("failed to get database: %w", err) + } + + // 创建迁移器 + migrator := migration.NewMigrator(db) + + // 自动发现并加载迁移文件 + migrations, err := migration.LoadMigrationsFromFiles(migrationsDir, "*.sql") + if err != nil { + return fmt.Errorf("failed to load migrations: %w", err) + } + + if len(migrations) == 0 { + f.LogInfo("在目录 '%s' 中没有找到迁移文件", migrationsDir) + return nil + } + + migrator.AddMigrations(migrations...) + + // 执行迁移 + if err := migrator.Up(); err != nil { + return fmt.Errorf("failed to run migrations: %w", err) + } + + f.LogInfo("迁移执行成功: %d 个迁移文件", len(migrations)) + return nil +} + +// GetMigrationStatus 获取迁移状态(黑盒模式,推荐使用) +// migrationsDir: 迁移文件目录 +// 返回迁移状态列表,包含版本、描述、是否已应用等信息 +// +// 示例: +// +// fac, _ := factory.NewFactoryFromFile("config.json") +// status, err := fac.GetMigrationStatus("migrations") +// if err != nil { +// log.Fatal(err) +// } +// for _, s := range status { +// fmt.Printf("Version: %s, Applied: %v\n", s.Version, s.Applied) +// } +func (f *Factory) GetMigrationStatus(migrationsDir string) ([]migration.MigrationStatus, error) { + // 获取数据库连接 + db, err := f.getDatabase() + if err != nil { + return nil, fmt.Errorf("failed to get database: %w", err) + } + + // 创建迁移器 + migrator := migration.NewMigrator(db) + + // 加载迁移文件 + migrations, err := migration.LoadMigrationsFromFiles(migrationsDir, "*.sql") + if err != nil { + return nil, fmt.Errorf("failed to load migrations: %w", err) + } + + migrator.AddMigrations(migrations...) + + // 获取状态 + status, err := migrator.Status() + if err != nil { + return nil, fmt.Errorf("failed to get migration status: %w", err) + } + + return status, nil +} diff --git a/migration/helper.go b/migration/helper.go index 4cb77f1..ffb23f3 100644 --- a/migration/helper.go +++ b/migration/helper.go @@ -25,46 +25,41 @@ import ( // // import "git.toowon.com/jimmy/go-common/migration" // migration.RunMigrationsFromConfig("config.json", "migrations") +// // 或使用默认迁移目录 +// migration.RunMigrationsFromConfig("config.json", "") func RunMigrationsFromConfig(configFile, migrationsDir string) error { - // 加载配置 - cfg, err := loadConfigFromFileOrEnv(configFile) - if err != nil { - return fmt.Errorf("加载配置失败: %w", err) - } - - // 连接数据库 - db, err := connectDB(cfg) - if err != nil { - return fmt.Errorf("连接数据库失败: %w", err) - } - - // 创建迁移器 - migrator := NewMigrator(db) - - // 加载迁移文件 - migrations, err := LoadMigrationsFromFiles(migrationsDir, "*.sql") - if err != nil { - return fmt.Errorf("加载迁移文件失败: %w", err) - } - - if len(migrations) == 0 { - fmt.Printf("在目录 '%s' 中没有找到迁移文件\n", migrationsDir) - return nil - } - - migrator.AddMigrations(migrations...) - - // 执行迁移 - if err := migrator.Up(); err != nil { - return fmt.Errorf("执行迁移失败: %w", err) - } - - fmt.Println("✓ 迁移执行成功") - return nil + return RunMigrationsFromConfigWithCommand(configFile, migrationsDir, "up") } -// RunMigrationsFromConfigWithCommand 从配置文件运行迁移(支持命令) -// command: "up", "down", "status" +// RunMigrationsFromConfigWithCommand 从配置文件运行迁移(支持命令,黑盒模式) +// +// 这是最简单的迁移方式,内部自动处理: +// - 配置加载(支持文件、环境变量、默认路径) +// - 数据库连接(自动识别数据库类型) +// - 迁移文件加载和执行 +// +// 参数: +// - configFile: 配置文件路径,支持: +// - 空字符串:自动查找(config.json, ../config.json) +// - 环境变量 DATABASE_URL:直接使用数据库URL +// - migrationsDir: 迁移文件目录,支持: +// - 空字符串:使用默认目录 "migrations" +// - 相对路径或绝对路径 +// - command: 命令,支持 "up", "down", "status" +// +// 使用示例: +// +// // 最简单:使用默认配置和默认迁移目录 +// migration.RunMigrationsFromConfigWithCommand("", "", "up") +// +// // 指定配置文件,使用默认迁移目录 +// migration.RunMigrationsFromConfigWithCommand("config.json", "", "up") +// +// // 指定配置和迁移目录 +// migration.RunMigrationsFromConfigWithCommand("config.json", "scripts/sql", "up") +// +// // 使用环境变量 +// // DATABASE_URL="mysql://..." migration.RunMigrationsFromConfigWithCommand("", "migrations", "up") func RunMigrationsFromConfigWithCommand(configFile, migrationsDir, command string) error { // 加载配置 cfg, err := loadConfigFromFileOrEnv(configFile) @@ -78,6 +73,11 @@ func RunMigrationsFromConfigWithCommand(configFile, migrationsDir, command strin return fmt.Errorf("连接数据库失败: %w", err) } + // 使用默认迁移目录(如果未指定) + if migrationsDir == "" { + migrationsDir = "migrations" + } + // 创建迁移器 migrator := NewMigrator(db) @@ -152,6 +152,7 @@ func loadConfigFromFileOrEnv(configFile string) (*config.Config, error) { } // connectDB 连接数据库 +// 与 factory.getDatabase 保持一致的实现,避免代码重复 func connectDB(cfg *config.Config) (*gorm.DB, error) { if cfg.Database == nil { return nil, fmt.Errorf("数据库配置为空") @@ -178,15 +179,30 @@ func connectDB(cfg *config.Config) (*gorm.DB, error) { return nil, err } - // 配置连接池 + // 配置连接池(与 factory.getDatabase 保持一致) sqlDB, err := db.DB() if err != nil { return nil, err } - sqlDB.SetMaxOpenConns(10) - sqlDB.SetMaxIdleConns(5) - sqlDB.SetConnMaxLifetime(time.Hour) + // 使用配置文件中的连接池参数,如果没有配置则使用默认值 + if cfg.Database.MaxOpenConns > 0 { + sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns) + } else { + sqlDB.SetMaxOpenConns(10) // 默认值 + } + + if cfg.Database.MaxIdleConns > 0 { + sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns) + } else { + sqlDB.SetMaxIdleConns(5) // 默认值 + } + + if cfg.Database.ConnMaxLifetime > 0 { + sqlDB.SetConnMaxLifetime(time.Duration(cfg.Database.ConnMaxLifetime) * time.Second) + } else { + sqlDB.SetConnMaxLifetime(time.Hour) // 默认值 + } return db, nil } @@ -213,4 +229,3 @@ func printMigrationStatus(status []MigrationStatus) { fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") fmt.Println() } - diff --git a/migration/migration.go b/migration/migration.go index ef01b12..ca0db09 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -72,17 +72,38 @@ func (m *Migrator) initTable() error { } if !exists { - // 创建迁移记录表 + // 创建迁移记录表(包含执行时间字段) err = m.db.Exec(fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( version VARCHAR(255) PRIMARY KEY, description VARCHAR(255), - applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + execution_time INT COMMENT '执行耗时(ms)' ) `, m.tableName)).Error if err != nil { return fmt.Errorf("failed to create migration table: %w", err) } + } else { + // 表已存在,检查是否有 execution_time 字段(向后兼容) + // 注意:这个检查可能在某些数据库中失败,但不影响功能 + // 如果字段不存在,记录执行时间时会失败,但不影响迁移执行 + var hasExecutionTime bool + checkSQL := fmt.Sprintf(` + SELECT COUNT(*) > 0 + FROM information_schema.columns + WHERE table_schema = CURRENT_SCHEMA() + AND table_name = '%s' + AND column_name = 'execution_time' + `, m.tableName) + err = m.db.Raw(checkSQL).Scan(&hasExecutionTime).Error + if err == nil && !hasExecutionTime { + // 尝试添加字段(如果失败不影响功能) + _ = m.db.Exec(fmt.Sprintf(` + ALTER TABLE %s + ADD COLUMN execution_time INT COMMENT '执行耗时(ms)' + `, m.tableName)) + } } return nil @@ -109,17 +130,34 @@ func (m *Migrator) getAppliedMigrations() (map[string]bool, error) { } // recordMigration 记录迁移 -func (m *Migrator) recordMigration(version, description string, isUp bool) error { +func (m *Migrator) recordMigration(version, description string, isUp bool, executionTime ...int) error { if err := m.initTable(); err != nil { return err } if isUp { - // 记录迁移 - err := m.db.Exec(fmt.Sprintf(` - INSERT INTO %s (version, description, applied_at) - VALUES (?, ?, ?) - `, m.tableName), version, description, time.Now()).Error + // 记录迁移(包含执行时间,如果提供了) + var err error + if len(executionTime) > 0 && executionTime[0] > 0 { + // 尝试插入执行时间(如果字段存在) + err = m.db.Exec(fmt.Sprintf(` + INSERT INTO %s (version, description, applied_at, execution_time) + VALUES (?, ?, ?, ?) + `, m.tableName), version, description, time.Now(), executionTime[0]).Error + if err != nil { + // 如果失败(可能是字段不存在),尝试不包含执行时间 + err = m.db.Exec(fmt.Sprintf(` + INSERT INTO %s (version, description, applied_at) + VALUES (?, ?, ?) + `, m.tableName), version, description, time.Now()).Error + } + } else { + // 不包含执行时间 + err = m.db.Exec(fmt.Sprintf(` + INSERT INTO %s (version, description, applied_at) + VALUES (?, ?, ?) + `, m.tableName), version, description, time.Now()).Error + } if err != nil { return fmt.Errorf("failed to record migration: %w", err) } @@ -160,6 +198,9 @@ func (m *Migrator) Up() error { return fmt.Errorf("migration %s has no Up function", migration.Version) } + // 记录开始时间 + startTime := time.Now() + // 开始事务 tx := m.db.Begin() if tx.Error != nil { @@ -172,8 +213,11 @@ func (m *Migrator) Up() error { return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err) } - // 记录迁移 - if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true); err != nil { + // 计算执行时间(毫秒) + executionTime := int(time.Since(startTime).Milliseconds()) + + // 记录迁移(包含执行时间) + if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true, executionTime); err != nil { tx.Rollback() return err } @@ -183,7 +227,7 @@ func (m *Migrator) Up() error { return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err) } - fmt.Printf("Applied migration: %s - %s\n", migration.Version, migration.Description) + fmt.Printf("Applied migration: %s - %s (耗时: %dms)\n", migration.Version, migration.Description, executionTime) } return nil @@ -246,12 +290,29 @@ func (m *Migrator) Down() error { } // recordMigrationWithDB 使用指定的数据库连接记录迁移 -func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool) error { +func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool, executionTime ...int) error { if isUp { - err := db.Exec(fmt.Sprintf(` - INSERT INTO %s (version, description, applied_at) - VALUES (?, ?, ?) - `, m.tableName), version, description, time.Now()).Error + var err error + if len(executionTime) > 0 && executionTime[0] > 0 { + // 尝试插入执行时间(如果字段存在) + err = db.Exec(fmt.Sprintf(` + INSERT INTO %s (version, description, applied_at, execution_time) + VALUES (?, ?, ?, ?) + `, m.tableName), version, description, time.Now(), executionTime[0]).Error + if err != nil { + // 如果失败(可能是字段不存在),尝试不包含执行时间 + err = db.Exec(fmt.Sprintf(` + INSERT INTO %s (version, description, applied_at) + VALUES (?, ?, ?) + `, m.tableName), version, description, time.Now()).Error + } + } else { + // 不包含执行时间 + err = db.Exec(fmt.Sprintf(` + INSERT INTO %s (version, description, applied_at) + VALUES (?, ?, ?) + `, m.tableName), version, description, time.Now()).Error + } if err != nil { return fmt.Errorf("failed to record migration: %w", err) } @@ -300,10 +361,96 @@ type MigrationStatus struct { Applied bool } +// splitSQL 分割SQL语句,处理多行SQL、注释等 +// 支持单行注释(--)、多行注释(/* */)、按分号分割语句 +func splitSQL(content string) []string { + var statements []string + var current strings.Builder + + lines := strings.Split(content, "\n") + inMultiLineComment := false + + for _, line := range lines { + trimmedLine := strings.TrimSpace(line) + + // 跳过空行 + if trimmedLine == "" { + continue + } + + // 处理多行注释 + if strings.HasPrefix(trimmedLine, "/*") { + inMultiLineComment = true + } + if strings.HasSuffix(trimmedLine, "*/") { + inMultiLineComment = false + continue + } + if inMultiLineComment { + continue + } + + // 跳过单行注释 + if strings.HasPrefix(trimmedLine, "--") { + continue + } + + // 添加到当前语句 + current.WriteString(line) + current.WriteString("\n") + + // 检查是否是完整语句(以分号结尾) + if strings.HasSuffix(trimmedLine, ";") { + stmt := strings.TrimSpace(current.String()) + if stmt != "" && !strings.HasPrefix(stmt, "--") { + statements = append(statements, stmt) + } + current.Reset() + } + } + + // 添加最后一个语句(如果没有分号结尾) + if current.Len() > 0 { + stmt := strings.TrimSpace(current.String()) + if stmt != "" && !strings.HasPrefix(stmt, "--") { + statements = append(statements, stmt) + } + } + + return statements +} + +// parseMigrationFileName 解析迁移文件名,支持多种格式 +// 格式1: 数字前缀 - 01_init_schema.sql +// 格式2: 时间戳 - 20240101000001_create_users.sql +// 格式3: 带.up后缀 - 20240101000001_create_users.up.sql +// 返回: (version, description, error) +func parseMigrationFileName(baseName string) (string, string, error) { + // 移除扩展名 + nameWithoutExt := strings.TrimSuffix(baseName, filepath.Ext(baseName)) + // 移除 .up 后缀(如果存在) + nameWithoutExt = strings.TrimSuffix(nameWithoutExt, ".up") + + // 解析版本号和描述 + parts := strings.SplitN(nameWithoutExt, "_", 2) + if len(parts) < 2 { + // 如果只有一个部分,尝试作为版本号(向后兼容) + return nameWithoutExt, baseName, nil + } + + version := parts[0] + description := strings.Join(parts[1:], "_") + + return version, description, nil +} + // LoadMigrationsFromFiles 从文件系统加载迁移文件 // dir: 迁移文件目录 // pattern: 文件命名模式,例如 "*.sql" 或 "*.up.sql" -// 文件命名格式: {version}_{description}.sql 或 {version}_{description}.up.sql +// 文件命名格式支持: +// - 数字前缀: 01_init_schema.sql +// - 时间戳: 20240101000001_create_users.sql +// - 带.up后缀: 20240101000001_create_users.up.sql func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) { files, err := filepath.Glob(filepath.Join(dir, pattern)) if err != nil { @@ -313,19 +460,17 @@ func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) { migrations := make([]Migration, 0) for _, file := range files { baseName := filepath.Base(file) - // 移除扩展名 - nameWithoutExt := strings.TrimSuffix(baseName, filepath.Ext(baseName)) - // 移除 .up 后缀(如果存在) - nameWithoutExt = strings.TrimSuffix(nameWithoutExt, ".up") - // 解析版本号和描述 - parts := strings.SplitN(nameWithoutExt, "_", 2) - if len(parts) < 2 { - return nil, fmt.Errorf("invalid migration file name format: %s (expected: {version}_{description})", baseName) + // 跳过 .down.sql 文件(会在处理 .up.sql 或 .sql 时自动加载) + if strings.HasSuffix(baseName, ".down.sql") { + continue } - version := parts[0] - description := strings.Join(parts[1:], "_") + // 解析版本号和描述 + version, description, err := parseMigrationFileName(baseName) + if err != nil { + return nil, fmt.Errorf("invalid migration file name format: %s: %w", baseName, err) + } // 读取文件内容 content, err := os.ReadFile(file) @@ -343,26 +488,81 @@ func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) { downSQL = string(downContent) } + // 创建迁移,使用 SQL 分割功能 migration := Migration{ Version: version, Description: description, Up: func(db *gorm.DB) error { - return db.Exec(sqlContent).Error + // 分割 SQL 语句 + statements := splitSQL(sqlContent) + if len(statements) == 0 { + return nil // 空文件,跳过 + } + + // 执行每个 SQL 语句 + // 注意:某些 DDL 语句(如 CREATE TABLE)在某些数据库中会隐式提交事务 + // 因此这里不使用事务,而是逐个执行 + for i, stmt := range statements { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + + if err := db.Exec(stmt).Error; err != nil { + // 如果是表已存在的错误,记录警告但继续执行(向后兼容) + errStr := err.Error() + if strings.Contains(errStr, "already exists") || + strings.Contains(errStr, "Duplicate") || + strings.Contains(errStr, "duplicate") { + fmt.Printf("Warning: SQL statement %d in migration %s: %v\n", i+1, version, err) + continue + } + return fmt.Errorf("failed to execute SQL statement %d in migration %s: %w\nSQL: %s", i+1, version, err, stmt) + } + } + + return nil }, } if downSQL != "" { migration.Down = func(db *gorm.DB) error { - return db.Exec(downSQL).Error + // 分割 SQL 语句 + statements := splitSQL(downSQL) + if len(statements) == 0 { + return nil + } + + // 执行每个 SQL 语句 + for i, stmt := range statements { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + + if err := db.Exec(stmt).Error; err != nil { + return fmt.Errorf("failed to execute SQL statement %d in rollback %s: %w\nSQL: %s", i+1, version, err, stmt) + } + } + + return nil } } migrations = append(migrations, migration) } - // 按版本号排序 + // 按版本号排序(支持数字和时间戳混合排序) sort.Slice(migrations, func(i, j int) bool { - return migrations[i].Version < migrations[j].Version + vi, vj := migrations[i].Version, migrations[j].Version + // 尝试按数字排序(如果是数字前缀) + if viNum, err1 := strconv.Atoi(vi); err1 == nil { + if vjNum, err2 := strconv.Atoi(vj); err2 == nil { + return viNum < vjNum + } + } + // 否则按字符串排序 + return vi < vj }) return migrations, nil diff --git a/templates/migrate/main.go b/templates/migrate/main.go index 5399322..1ef6865 100644 --- a/templates/migrate/main.go +++ b/templates/migrate/main.go @@ -69,6 +69,9 @@ func main() { } // 获取命令(默认up) + // 支持两种方式: + // 1. 位置参数:./migrate up + // 2. 标志参数:./migrate -cmd=up(向后兼容) command := "up" args := flag.Args() if len(args) > 0 { @@ -83,16 +86,18 @@ func main() { } // 获取配置文件路径(优先级:命令行 > 环境变量 > 默认值) + // 如果未指定,RunMigrationsFromConfigWithCommand 会自动查找 if configFile == "" { - configFile = getEnv("CONFIG_FILE", "config.json") + configFile = getEnv("CONFIG_FILE", "") } // 获取迁移目录(优先级:命令行 > 环境变量 > 默认值) + // 如果未指定,RunMigrationsFromConfigWithCommand 会使用默认值 "migrations" if migrationsDir == "" { - migrationsDir = getEnv("MIGRATIONS_DIR", "migrations") + migrationsDir = getEnv("MIGRATIONS_DIR", "") } - // 执行迁移 + // 执行迁移(黑盒模式:内部自动处理所有细节) if err := migration.RunMigrationsFromConfigWithCommand(configFile, migrationsDir, command); err != nil { fmt.Fprintf(os.Stderr, "错误: %v\n", err) os.Exit(1) @@ -144,4 +149,3 @@ func printHelp() { fmt.Println(" 3. 环境变量 DATABASE_URL") fmt.Println(" 4. 默认值(config.json 和 migrations)") } -