初始版本,工具基础类

This commit is contained in:
2025-11-30 13:02:34 +08:00
commit ea4e2e305d
37 changed files with 7480 additions and 0 deletions

373
migration/migration.go Normal file
View File

@@ -0,0 +1,373 @@
package migration
import (
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
"gorm.io/gorm"
)
// Migration 表示一个数据库迁移
type Migration struct {
Version string
Description string
Up func(*gorm.DB) error
Down func(*gorm.DB) error
}
// Migrator 数据库迁移器
type Migrator struct {
db *gorm.DB
migrations []Migration
tableName string
}
// NewMigrator 创建新的迁移器
// db: GORM数据库连接
// tableName: 存储迁移记录的表名,默认为 "schema_migrations"
func NewMigrator(db *gorm.DB, tableName ...string) *Migrator {
table := "schema_migrations"
if len(tableName) > 0 && tableName[0] != "" {
table = tableName[0]
}
return &Migrator{
db: db,
migrations: make([]Migration, 0),
tableName: table,
}
}
// AddMigration 添加迁移
func (m *Migrator) AddMigration(migration Migration) {
m.migrations = append(m.migrations, migration)
}
// AddMigrations 批量添加迁移
func (m *Migrator) AddMigrations(migrations ...Migration) {
m.migrations = append(m.migrations, migrations...)
}
// initTable 初始化迁移记录表
func (m *Migrator) initTable() error {
// 检查表是否存在
var exists bool
err := m.db.Raw(fmt.Sprintf(`
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_schema = CURRENT_SCHEMA()
AND table_name = '%s'
)
`, m.tableName)).Scan(&exists).Error
if err != nil {
// 如果查询失败可能是SQLite或其他数据库尝试直接创建
exists = false
}
if !exists {
// 创建迁移记录表
err = m.db.Exec(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
version VARCHAR(255) PRIMARY KEY,
description VARCHAR(255),
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
`, m.tableName)).Error
if err != nil {
return fmt.Errorf("failed to create migration table: %w", err)
}
}
return nil
}
// getAppliedMigrations 获取已应用的迁移版本
func (m *Migrator) getAppliedMigrations() (map[string]bool, error) {
if err := m.initTable(); err != nil {
return nil, err
}
var versions []string
err := m.db.Table(m.tableName).Select("version").Pluck("version", &versions).Error
if err != nil {
return nil, fmt.Errorf("failed to get applied migrations: %w", err)
}
applied := make(map[string]bool)
for _, v := range versions {
applied[v] = true
}
return applied, nil
}
// recordMigration 记录迁移
func (m *Migrator) recordMigration(version, description string, isUp bool) error {
if err := m.initTable(); err != nil {
return err
}
if isUp {
// 记录迁移
err := m.db.Exec(fmt.Sprintf(`
INSERT INTO %s (version, description, applied_at)
VALUES (?, ?, ?)
`, m.tableName), version, description, time.Now()).Error
if err != nil {
return fmt.Errorf("failed to record migration: %w", err)
}
} else {
// 删除迁移记录
err := m.db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error
if err != nil {
return fmt.Errorf("failed to remove migration record: %w", err)
}
}
return nil
}
// Up 执行所有未应用的迁移
func (m *Migrator) Up() error {
if err := m.initTable(); err != nil {
return err
}
applied, err := m.getAppliedMigrations()
if err != nil {
return err
}
// 排序迁移
sort.Slice(m.migrations, func(i, j int) bool {
return m.migrations[i].Version < m.migrations[j].Version
})
// 执行未应用的迁移
for _, migration := range m.migrations {
if applied[migration.Version] {
continue
}
if migration.Up == nil {
return fmt.Errorf("migration %s has no Up function", migration.Version)
}
// 开始事务
tx := m.db.Begin()
if tx.Error != nil {
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
}
// 执行迁移
if err := migration.Up(tx); err != nil {
tx.Rollback()
return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
}
// 记录迁移
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, true); err != nil {
tx.Rollback()
return err
}
// 提交事务
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err)
}
fmt.Printf("Applied migration: %s - %s\n", migration.Version, migration.Description)
}
return nil
}
// Down 回滚最后一个迁移
func (m *Migrator) Down() error {
if err := m.initTable(); err != nil {
return err
}
applied, err := m.getAppliedMigrations()
if err != nil {
return err
}
// 排序迁移(倒序)
sort.Slice(m.migrations, func(i, j int) bool {
return m.migrations[i].Version > m.migrations[j].Version
})
// 找到最后一个已应用的迁移并回滚
for _, migration := range m.migrations {
if !applied[migration.Version] {
continue
}
if migration.Down == nil {
return fmt.Errorf("migration %s has no Down function", migration.Version)
}
// 开始事务
tx := m.db.Begin()
if tx.Error != nil {
return fmt.Errorf("failed to begin transaction: %w", tx.Error)
}
// 执行回滚
if err := migration.Down(tx); err != nil {
tx.Rollback()
return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err)
}
// 删除迁移记录
if err := m.recordMigrationWithDB(tx, migration.Version, migration.Description, false); err != nil {
tx.Rollback()
return err
}
// 提交事务
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err)
}
fmt.Printf("Rolled back migration: %s - %s\n", migration.Version, migration.Description)
return nil
}
return fmt.Errorf("no migrations to rollback")
}
// recordMigrationWithDB 使用指定的数据库连接记录迁移
func (m *Migrator) recordMigrationWithDB(db *gorm.DB, version, description string, isUp bool) error {
if isUp {
err := db.Exec(fmt.Sprintf(`
INSERT INTO %s (version, description, applied_at)
VALUES (?, ?, ?)
`, m.tableName), version, description, time.Now()).Error
if err != nil {
return fmt.Errorf("failed to record migration: %w", err)
}
} else {
err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", m.tableName), version).Error
if err != nil {
return fmt.Errorf("failed to remove migration record: %w", err)
}
}
return nil
}
// Status 查看迁移状态
func (m *Migrator) Status() ([]MigrationStatus, error) {
if err := m.initTable(); err != nil {
return nil, err
}
applied, err := m.getAppliedMigrations()
if err != nil {
return nil, err
}
// 排序迁移
sort.Slice(m.migrations, func(i, j int) bool {
return m.migrations[i].Version < m.migrations[j].Version
})
status := make([]MigrationStatus, 0, len(m.migrations))
for _, migration := range m.migrations {
status = append(status, MigrationStatus{
Version: migration.Version,
Description: migration.Description,
Applied: applied[migration.Version],
})
}
return status, nil
}
// MigrationStatus 迁移状态
type MigrationStatus struct {
Version string
Description string
Applied bool
}
// LoadMigrationsFromFiles 从文件系统加载迁移文件
// dir: 迁移文件目录
// pattern: 文件命名模式,例如 "*.sql" 或 "*.up.sql"
// 文件命名格式: {version}_{description}.sql 或 {version}_{description}.up.sql
func LoadMigrationsFromFiles(dir string, pattern string) ([]Migration, error) {
files, err := filepath.Glob(filepath.Join(dir, pattern))
if err != nil {
return nil, fmt.Errorf("failed to glob migration files: %w", err)
}
migrations := make([]Migration, 0)
for _, file := range files {
baseName := filepath.Base(file)
// 移除扩展名
nameWithoutExt := strings.TrimSuffix(baseName, filepath.Ext(baseName))
// 移除 .up 后缀(如果存在)
nameWithoutExt = strings.TrimSuffix(nameWithoutExt, ".up")
// 解析版本号和描述
parts := strings.SplitN(nameWithoutExt, "_", 2)
if len(parts) < 2 {
return nil, fmt.Errorf("invalid migration file name format: %s (expected: {version}_{description})", baseName)
}
version := parts[0]
description := strings.Join(parts[1:], "_")
// 读取文件内容
content, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("failed to read migration file %s: %w", file, err)
}
sqlContent := string(content)
// 查找对应的 down 文件
downFile := strings.Replace(file, ".up.sql", ".down.sql", 1)
downFile = strings.Replace(downFile, ".sql", ".down.sql", 1)
var downSQL string
if downContent, err := os.ReadFile(downFile); err == nil {
downSQL = string(downContent)
}
migration := Migration{
Version: version,
Description: description,
Up: func(db *gorm.DB) error {
return db.Exec(sqlContent).Error
},
}
if downSQL != "" {
migration.Down = func(db *gorm.DB) error {
return db.Exec(downSQL).Error
}
}
migrations = append(migrations, migration)
}
// 按版本号排序
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].Version < migrations[j].Version
})
return migrations, nil
}
// GenerateVersion 生成迁移版本号(基于时间戳)
func GenerateVersion() string {
return strconv.FormatInt(time.Now().Unix(), 10)
}