package factory

import (
	"context"
	"fmt"
	"net/http"
	"sync"
	"time"

	"git.toowon.com/jimmy/go-common/config"
	"git.toowon.com/jimmy/go-common/email"
	"git.toowon.com/jimmy/go-common/excel"
	"git.toowon.com/jimmy/go-common/i18n"
	"git.toowon.com/jimmy/go-common/logger"
	"git.toowon.com/jimmy/go-common/middleware"
	"git.toowon.com/jimmy/go-common/migration"
	"git.toowon.com/jimmy/go-common/sms"
	"git.toowon.com/jimmy/go-common/storage"
	"github.com/redis/go-redis/v9"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

var (
	// defaultFactory is the process-wide Factory set by Init and returned by Default.
	defaultFactory *Factory
)

// Factory is the entry point of the toolkit: it holds the loaded configuration
// and exposes lazy getters that build each component on first use and cache it.
type Factory struct {
	cfg *config.Config

	storage storage.Storage
	logger  *logger.Logger
	email   *email.Email
	sms     *sms.SMS
	db      *gorm.DB
	redis   *redis.Client
	i18n    *i18n.I18n
	excel   *excel.Excel
	chain   *middleware.Chain

	// mu is held by the lazy getters while they create and store an instance.
	mu sync.Mutex
}

// Option customizes a Factory at construction time (for example, to replace a
// module implementation).
type Option func(*Factory)

// WithStorage injects a custom storage implementation. When set, the Factory
// returns it instead of building one from configuration.
func WithStorage(s storage.Storage) Option {
	return func(f *Factory) {
		f.storage = s
	}
}

// Init loads the configuration file at filePath and installs the resulting
// Factory as the global default. It is meant to be called once at startup.
//
// If a logger can be created it is also registered as the logger package's
// default; a logger creation failure is not reported as an Init failure.
// Only a configuration load error is returned.
func Init(filePath string, opts ...Option) error {
	cfg, err := config.LoadFromFile(filePath)
	if err != nil {
		return err
	}

	defaultFactory = New(cfg, opts...)

	if l, err := defaultFactory.Logger(); err == nil {
		logger.SetDefaultLogger(l)
	}
	return nil
}

// Default returns the global Factory, or nil if Init has not assigned one.
func Default() *Factory {
	return defaultFactory
}

// New creates a Factory from cfg and applies opts in order.
// Nil options are skipped rather than invoked.
func New(cfg *config.Config, opts ...Option) *Factory {
	f := &Factory{cfg: cfg}
	for _, opt := range opts {
		if opt == nil {
			// Calling a nil func value would panic; treat it as "no option".
			continue
		}
		opt(f)
	}
	return f
}

// Config returns the configuration the Factory was created with (may be nil).
func (f *Factory) Config() *config.Config {
	return f.cfg
}

// getLogger returns the cached logger, creating it from the logger section of
// the configuration on first use. A nil Factory configuration results in a nil
// logger config being passed to logger.NewLogger.
//
// NOTE(review): the first nil check reads f.logger without holding f.mu, so it
// is not synchronized with the assignment below; consider sync.Once or an
// atomic pointer if the race detector flags it.
func (f *Factory) getLogger() (*logger.Logger, error) {
	if f.logger != nil {
		return f.logger, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	// Re-check under the lock: another goroutine may have created it meanwhile.
	if f.logger != nil {
		return f.logger, nil
	}

	var cfg *config.LoggerConfig
	if f.cfg != nil {
		cfg = f.cfg.Logger
	}

	l, err := logger.NewLogger(cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create logger: %w", err)
	}
	f.logger = l
	return l, nil
}

// Logger returns the shared logger, creating it on first use.
func (f *Factory) Logger() (*logger.Logger, error) {
	return f.getLogger()
}

// getDatabase returns the cached GORM connection, opening it on first use.
//
// The driver is selected by cfg.Database.Type ("mysql", "postgres" or
// "sqlite"); any other value is an error. Pool limits (max open/idle
// connections, connection max lifetime in seconds) are applied only when the
// configured value is positive.
func (f *Factory) getDatabase() (*gorm.DB, error) {
	if f.db != nil {
		return f.db, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	// Re-check under the lock: another goroutine may have connected meanwhile.
	if f.db != nil {
		return f.db, nil
	}
	if f.cfg == nil || f.cfg.Database == nil {
		return nil, fmt.Errorf("database config is nil")
	}

	dsn, err := f.cfg.GetDatabaseDSN()
	if err != nil {
		return nil, err
	}

	var db *gorm.DB
	switch f.cfg.Database.Type {
	case "mysql":
		db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
	case "postgres":
		db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
	case "sqlite":
		db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
	default:
		return nil, fmt.Errorf("unsupported database type: %s", f.cfg.Database.Type)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	// Pool settings live on the underlying *sql.DB.
	sqlDB, err := db.DB()
	if err != nil {
		return nil, err
	}
	if f.cfg.Database.MaxOpenConns > 0 {
		sqlDB.SetMaxOpenConns(f.cfg.Database.MaxOpenConns)
	}
	if f.cfg.Database.MaxIdleConns > 0 {
		sqlDB.SetMaxIdleConns(f.cfg.Database.MaxIdleConns)
	}
	if f.cfg.Database.ConnMaxLifetime > 0 {
		sqlDB.SetConnMaxLifetime(time.Duration(f.cfg.Database.ConnMaxLifetime) * time.Second)
	}

	f.db = db
	return db, nil
}
// Database 获取 GORM 数据库连接 func (f *Factory) Database() (*gorm.DB, error) { return f.getDatabase() } func (f *Factory) getRedis() (*redis.Client, error) { if f.redis != nil { return f.redis, nil } f.mu.Lock() defer f.mu.Unlock() if f.redis != nil { return f.redis, nil } if f.cfg == nil || f.cfg.Redis == nil { return nil, fmt.Errorf("redis config is nil") } addr := f.cfg.GetRedisAddr() if addr == "" { return nil, fmt.Errorf("redis address is empty") } rc := f.cfg.Redis client := redis.NewClient(&redis.Options{ Addr: addr, Password: rc.Password, DB: rc.Database, PoolSize: rc.PoolSize, MinIdleConns: rc.MinIdleConns, MaxRetries: rc.MaxRetries, DialTimeout: time.Duration(rc.DialTimeout) * time.Second, ReadTimeout: time.Duration(rc.ReadTimeout) * time.Second, WriteTimeout: time.Duration(rc.WriteTimeout) * time.Second, }) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(rc.DialTimeout)*time.Second) defer cancel() if _, err := client.Ping(ctx).Result(); err != nil { _ = client.Close() return nil, fmt.Errorf("failed to connect to redis: %w", err) } f.redis = client return client, nil } // Redis 获取 Redis 客户端 func (f *Factory) Redis() (*redis.Client, error) { return f.getRedis() } func (f *Factory) getStorage() (storage.Storage, error) { if f.storage != nil { return f.storage, nil } f.mu.Lock() defer f.mu.Unlock() if f.storage != nil { return f.storage, nil } if f.cfg == nil { return nil, fmt.Errorf("config is nil") } var storageType storage.StorageType switch { case f.cfg.GetLocalStorage() != nil: storageType = storage.StorageTypeLocal case f.cfg.MinIO != nil: storageType = storage.StorageTypeMinIO case f.cfg.OSS != nil: storageType = storage.StorageTypeOSS default: return nil, fmt.Errorf("no storage config found (LocalStorage, OSS or MinIO)") } s, err := storage.NewStorage(storageType, f.cfg) if err != nil { return nil, err } f.storage = s return s, nil } // Storage 获取存储对象 func (f *Factory) Storage() (storage.Storage, error) { return f.getStorage() } 
func (f *Factory) getEmail() (*email.Email, error) { if f.email != nil { return f.email, nil } f.mu.Lock() defer f.mu.Unlock() if f.email != nil { return f.email, nil } if f.cfg == nil || f.cfg.Email == nil { return nil, fmt.Errorf("email config is nil") } f.email = email.NewEmail(f.cfg) return f.email, nil } // Email 获取邮件客户端 func (f *Factory) Email() (*email.Email, error) { return f.getEmail() } func (f *Factory) getSMS() (*sms.SMS, error) { if f.sms != nil { return f.sms, nil } f.mu.Lock() defer f.mu.Unlock() if f.sms != nil { return f.sms, nil } if f.cfg == nil || f.cfg.SMS == nil { return nil, fmt.Errorf("sms config is nil") } f.sms = sms.NewSMS(f.cfg) return f.sms, nil } // SMS 获取短信客户端 func (f *Factory) SMS() (*sms.SMS, error) { return f.getSMS() } func (f *Factory) getI18n() (*i18n.I18n, error) { if f.i18n != nil { return f.i18n, nil } f.mu.Lock() defer f.mu.Unlock() if f.i18n != nil { return f.i18n, nil } if f.cfg == nil || f.cfg.I18n == nil { return nil, fmt.Errorf("i18n config is nil") } i := i18n.NewI18n(f.cfg.I18n.DefaultLang) if f.cfg.I18n.LocalesDir != "" { if err := i.LoadFromDir(f.cfg.I18n.LocalesDir); err != nil { return nil, err } } f.i18n = i return i, nil } // I18n 获取国际化对象 func (f *Factory) I18n() (*i18n.I18n, error) { return f.getI18n() } func (f *Factory) getExcel() *excel.Excel { if f.excel != nil { return f.excel } f.mu.Lock() defer f.mu.Unlock() if f.excel == nil { f.excel = excel.NewExcel() } return f.excel } // Excel 获取 Excel 导出器 func (f *Factory) Excel() *excel.Excel { return f.getExcel() } // MiddlewareChain 获取默认中间件链 func (f *Factory) MiddlewareChain() *middleware.Chain { if f.chain != nil { return f.chain } f.mu.Lock() defer f.mu.Unlock() if f.chain != nil { return f.chain } var mws []func(http.Handler) http.Handler l, _ := f.getLogger() i18nInst, _ := f.getI18n() mws = append(mws, middleware.Recovery(&middleware.RecoveryConfig{ Logger: l, I18n: i18nInst, })) mws = append(mws, middleware.RequestID()) mws = append(mws, 
middleware.Logging(&middleware.LoggingConfig{Logger: l})) if f.cfg != nil && f.cfg.RateLimit != nil && f.cfg.RateLimit.Enable { limiter := middleware.NewTokenBucketLimiter( f.cfg.RateLimit.Rate, time.Duration(f.cfg.RateLimit.Period)*time.Second, ) var keyFunc func(r *http.Request) string if f.cfg.RateLimit.ByIP { keyFunc = func(r *http.Request) string { return middleware.GetClientIP(r) } } else if f.cfg.RateLimit.ByUserID { keyFunc = func(r *http.Request) string { return r.Header.Get("X-User-ID") } } mws = append(mws, middleware.RateLimit(&middleware.RateLimitConfig{ Limiter: limiter, KeyFunc: keyFunc, })) } if f.cfg != nil && f.cfg.CORS != nil { mws = append(mws, middleware.CORS(&middleware.CORSConfig{ AllowedOrigins: f.cfg.CORS.AllowedOrigins, AllowedMethods: f.cfg.CORS.AllowedMethods, AllowedHeaders: f.cfg.CORS.AllowedHeaders, ExposedHeaders: f.cfg.CORS.ExposedHeaders, AllowCredentials: f.cfg.CORS.AllowCredentials, MaxAge: f.cfg.CORS.MaxAge, })) } mws = append(mws, middleware.Language, middleware.Timezone) f.chain = middleware.NewChain(mws...) return f.chain } // Migrator 创建迁移器并加载指定目录下的 SQL 文件 func (f *Factory) Migrator(migrationsDir string) (*migration.Migrator, error) { db, err := f.getDatabase() if err != nil { return nil, err } dbType := "mysql" if f.cfg.Database != nil && f.cfg.Database.Type != "" { dbType = f.cfg.Database.Type } m := migration.NewMigratorWithType(db, dbType) migrations, err := migration.LoadMigrationsFromFiles(migrationsDir, "*.sql") if err != nil { return nil, err } m.AddMigrations(migrations...) return m, nil }