Files
go-common/factory/factory.go

405 lines
9.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package factory
import (
"context"
"fmt"
"net/http"
"sync"
"time"
"git.toowon.com/jimmy/go-common/config"
"git.toowon.com/jimmy/go-common/email"
"git.toowon.com/jimmy/go-common/excel"
"git.toowon.com/jimmy/go-common/i18n"
"git.toowon.com/jimmy/go-common/logger"
"git.toowon.com/jimmy/go-common/middleware"
"git.toowon.com/jimmy/go-common/migration"
"git.toowon.com/jimmy/go-common/sms"
"git.toowon.com/jimmy/go-common/storage"
"github.com/redis/go-redis/v9"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// defaultFactory is the process-wide Factory installed by Init and returned by
// Default. It remains nil until Init has loaded a configuration successfully.
//
// NOTE(review): it is written and read without synchronization, so Init is
// expected to finish (once, at startup) before other goroutines call Default.
var defaultFactory *Factory
// Factory is the library entry point: it holds the loaded configuration and
// lazily constructs each component (logger, database, Redis, storage, email,
// SMS, i18n, Excel exporter, middleware chain) on first request, caching the
// result for subsequent calls.
type Factory struct {
	// cfg is the configuration given to New. It may be nil: the logger is then
	// built from a nil LoggerConfig, while the other config-backed getters
	// return an error.
	cfg *config.Config

	// Lazily-created components. Each stays nil until its getter succeeds;
	// storage may also be pre-set through the WithStorage option.
	storage storage.Storage
	logger  *logger.Logger
	email   *email.Email
	sms     *sms.SMS
	db      *gorm.DB
	redis   *redis.Client
	i18n    *i18n.I18n
	excel   *excel.Excel
	chain   *middleware.Chain

	// mu serializes lazy initialization of all components above. It is one
	// shared, non-reentrant lock: code that holds mu must not call a getter
	// that locks it again, and a slow initialization (e.g. a network dial)
	// blocks every other getter that needs the lock meanwhile.
	mu sync.Mutex
}
// Option customizes a Factory while New builds it, typically to replace a
// module implementation that would otherwise be created lazily from config.
type Option func(*Factory)

// WithStorage makes the Factory use s as its storage backend. Because the
// storage getter returns any non-nil pre-set backend as-is, the config-driven
// backend selection is skipped entirely; passing nil leaves it in effect.
func WithStorage(s storage.Storage) Option {
	apply := func(target *Factory) { target.storage = s }
	return apply
}
// Init loads the configuration file at filePath, builds a Factory from it with
// opts, and installs that Factory as the global one returned by Default. It is
// meant to be called once during program startup.
//
// As a convenience Init also tries to build the logger and, on success,
// registers it with logger.SetDefaultLogger. A logger failure is deliberately
// not fatal here: the logger is only cached on success, so a later call to
// Factory.Logger retries construction and reports the error.
//
// The only error returned is the one from config.LoadFromFile, in which case
// the global Factory is left unchanged.
func Init(filePath string, opts ...Option) error {
	cfg, loadErr := config.LoadFromFile(filePath)
	if loadErr != nil {
		return loadErr
	}

	f := New(cfg, opts...)
	defaultFactory = f

	// Best effort: keep the package-level logger in sync when possible.
	l, logErr := f.Logger()
	if logErr != nil {
		return nil
	}
	logger.SetDefaultLogger(l)
	return nil
}

// Default returns the global Factory installed by Init, or nil if Init has not
// yet completed successfully.
func Default() *Factory {
	current := defaultFactory
	return current
}
// New builds a Factory around cfg and applies opts in the order given. No
// component is created here; each one is built lazily by its getter. cfg may
// be nil (the individual getters decide how to react to that).
func New(cfg *config.Config, opts ...Option) *Factory {
	f := new(Factory)
	f.cfg = cfg
	for i := range opts {
		opts[i](f)
	}
	return f
}

// Config returns the configuration the Factory was created with (possibly nil).
func (f *Factory) Config() *config.Config {
	return f.cfg
}
// getLogger returns the cached logger, creating it on first use from the
// Logger section of the configuration (nil is passed when the Factory has no
// configuration at all). Construction errors are not cached, so a later call
// retries.
//
// It locks f.mu while the logger is not built yet, so it must not be called
// by code that already holds f.mu.
//
// NOTE(review): the nil check before locking reads f.logger without
// synchronization while another goroutine may be writing it under f.mu; the
// race detector reports this on concurrent first use. A per-component
// sync.Once or atomic.Pointer would remove the race.
func (f *Factory) getLogger() (*logger.Logger, error) {
	if cached := f.logger; cached != nil {
		return cached, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	if cached := f.logger; cached != nil {
		return cached, nil
	}

	var loggerCfg *config.LoggerConfig
	if f.cfg != nil {
		loggerCfg = f.cfg.Logger
	}

	created, err := logger.NewLogger(loggerCfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create logger: %w", err)
	}
	f.logger = created
	return created, nil
}

// Logger returns the Factory's logger, building it on first call.
func (f *Factory) Logger() (*logger.Logger, error) {
	return f.getLogger()
}
// getDatabase returns the cached *gorm.DB, opening it on first use.
//
// The GORM driver is picked from cfg.Database.Type ("mysql", "postgres" or
// "sqlite"); any other value, including "", is rejected. After a successful
// open, the pool limits MaxOpenConns, MaxIdleConns and ConnMaxLifetime (in
// seconds) are applied only when they are positive. Failures are not cached,
// so a later call retries.
//
// f.mu is held for the whole connection attempt, blocking every other lazy
// getter that needs the lock in the meantime.
//
// NOTE(review): the nil check before locking reads f.db without
// synchronization while another goroutine may be writing it under f.mu; the
// race detector reports this on concurrent first use.
func (f *Factory) getDatabase() (*gorm.DB, error) {
	if conn := f.db; conn != nil {
		return conn, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	if conn := f.db; conn != nil {
		return conn, nil
	}
	if f.cfg == nil || f.cfg.Database == nil {
		return nil, fmt.Errorf("database config is nil")
	}

	dsn, err := f.cfg.GetDatabaseDSN()
	if err != nil {
		return nil, err
	}
	dbCfg := f.cfg.Database

	// Choose the dialector first so gorm.Open is called in exactly one place.
	var dialector gorm.Dialector
	switch dbCfg.Type {
	case "mysql":
		dialector = mysql.Open(dsn)
	case "postgres":
		dialector = postgres.Open(dsn)
	case "sqlite":
		dialector = sqlite.Open(dsn)
	default:
		return nil, fmt.Errorf("unsupported database type: %s", dbCfg.Type)
	}

	conn, err := gorm.Open(dialector, &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	pool, err := conn.DB()
	if err != nil {
		return nil, err
	}
	if n := dbCfg.MaxOpenConns; n > 0 {
		pool.SetMaxOpenConns(n)
	}
	if n := dbCfg.MaxIdleConns; n > 0 {
		pool.SetMaxIdleConns(n)
	}
	if secs := dbCfg.ConnMaxLifetime; secs > 0 {
		pool.SetConnMaxLifetime(time.Duration(secs) * time.Second)
	}

	f.db = conn
	return conn, nil
}

// Database returns the Factory's GORM database handle, connecting on first call.
func (f *Factory) Database() (*gorm.DB, error) {
	return f.getDatabase()
}
// getRedis returns the cached Redis client, creating and verifying it on first
// use.
//
// Timeouts in the Redis config are interpreted as seconds. The client is only
// cached after a successful PING; on failure it is closed and the error is
// returned, so a later call retries. f.mu is held during the PING, blocking
// other lazy getters that need the lock meanwhile.
//
// NOTE(review): the nil check before locking reads f.redis without
// synchronization while another goroutine may be writing it under f.mu; the
// race detector reports this on concurrent first use.
func (f *Factory) getRedis() (*redis.Client, error) {
	// defaultPingTimeout matches go-redis's default DialTimeout (5s), which the
	// client itself uses when DialTimeout is left at zero.
	const defaultPingTimeout = 5 * time.Second

	if f.redis != nil {
		return f.redis, nil
	}
	f.mu.Lock()
	defer f.mu.Unlock()
	if f.redis != nil {
		return f.redis, nil
	}
	if f.cfg == nil || f.cfg.Redis == nil {
		return nil, fmt.Errorf("redis config is nil")
	}
	addr := f.cfg.GetRedisAddr()
	if addr == "" {
		return nil, fmt.Errorf("redis address is empty")
	}
	rc := f.cfg.Redis
	client := redis.NewClient(&redis.Options{
		Addr:         addr,
		Password:     rc.Password,
		DB:           rc.Database,
		PoolSize:     rc.PoolSize,
		MinIdleConns: rc.MinIdleConns,
		MaxRetries:   rc.MaxRetries,
		DialTimeout:  time.Duration(rc.DialTimeout) * time.Second,
		ReadTimeout:  time.Duration(rc.ReadTimeout) * time.Second,
		WriteTimeout: time.Duration(rc.WriteTimeout) * time.Second,
	})

	// The connectivity check needs its own deadline. Deriving it directly from
	// an unset (zero) DialTimeout would yield a context that is already expired,
	// making the PING - and therefore Redis() - fail every time even though the
	// client itself would dial with its built-in default. Fall back to that
	// default instead.
	pingTimeout := time.Duration(rc.DialTimeout) * time.Second
	if pingTimeout <= 0 {
		pingTimeout = defaultPingTimeout
	}
	ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
	defer cancel()
	if _, err := client.Ping(ctx).Result(); err != nil {
		// Do not leak the pool of a client we are not going to cache.
		_ = client.Close()
		return nil, fmt.Errorf("failed to connect to redis: %w", err)
	}
	f.redis = client
	return client, nil
}

// Redis returns the Factory's Redis client, connecting on first call.
func (f *Factory) Redis() (*redis.Client, error) {
	return f.getRedis()
}
// getStorage returns the storage backend, building it on first use.
//
// A backend injected with WithStorage is returned as-is. Otherwise the backend
// type is picked from the configuration in fixed priority order: local storage
// (cfg.GetLocalStorage), then MinIO, then OSS — the first configured one wins.
// Construction errors are not cached, so a later call retries.
//
// NOTE(review): the nil check before locking reads f.storage without
// synchronization while another goroutine may be writing it under f.mu; the
// race detector reports this on concurrent first use.
func (f *Factory) getStorage() (storage.Storage, error) {
	if existing := f.storage; existing != nil {
		return existing, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	if existing := f.storage; existing != nil {
		return existing, nil
	}
	cfg := f.cfg
	if cfg == nil {
		return nil, fmt.Errorf("config is nil")
	}

	var kind storage.StorageType
	switch {
	case cfg.GetLocalStorage() != nil:
		kind = storage.StorageTypeLocal
	case cfg.MinIO != nil:
		kind = storage.StorageTypeMinIO
	case cfg.OSS != nil:
		kind = storage.StorageTypeOSS
	default:
		return nil, fmt.Errorf("no storage config found (LocalStorage, OSS or MinIO)")
	}

	backend, err := storage.NewStorage(kind, cfg)
	if err != nil {
		return nil, err
	}
	f.storage = backend
	return backend, nil
}

// Storage returns the Factory's storage backend, building it on first call.
func (f *Factory) Storage() (storage.Storage, error) {
	return f.getStorage()
}
// getEmail returns the email client, creating it on first use via
// email.NewEmail with the whole configuration. It fails when the configuration
// or its Email section is missing.
//
// NOTE(review): the nil check before locking reads f.email without
// synchronization while another goroutine may be writing it under f.mu.
func (f *Factory) getEmail() (*email.Email, error) {
	if client := f.email; client != nil {
		return client, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	if f.email == nil {
		if f.cfg == nil || f.cfg.Email == nil {
			return nil, fmt.Errorf("email config is nil")
		}
		f.email = email.NewEmail(f.cfg)
	}
	return f.email, nil
}

// Email returns the Factory's email client, creating it on first call.
func (f *Factory) Email() (*email.Email, error) {
	return f.getEmail()
}
// getSMS returns the SMS client, creating it on first use via sms.NewSMS with
// the whole configuration. It fails when the configuration or its SMS section
// is missing.
//
// NOTE(review): the nil check before locking reads f.sms without
// synchronization while another goroutine may be writing it under f.mu.
func (f *Factory) getSMS() (*sms.SMS, error) {
	if client := f.sms; client != nil {
		return client, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	if f.sms == nil {
		if f.cfg == nil || f.cfg.SMS == nil {
			return nil, fmt.Errorf("sms config is nil")
		}
		f.sms = sms.NewSMS(f.cfg)
	}
	return f.sms, nil
}

// SMS returns the Factory's SMS client, creating it on first call.
func (f *Factory) SMS() (*sms.SMS, error) {
	return f.getSMS()
}
// getI18n returns the i18n instance, creating it on first use with the
// configured default language and, when LocalesDir is non-empty, loading the
// locale files from that directory. Errors (missing config or a failed load)
// are not cached, so a later call retries.
//
// It locks f.mu while the instance is not built yet, so it must not be called
// by code that already holds f.mu.
//
// NOTE(review): the nil check before locking reads f.i18n without
// synchronization while another goroutine may be writing it under f.mu.
func (f *Factory) getI18n() (*i18n.I18n, error) {
	if inst := f.i18n; inst != nil {
		return inst, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	if inst := f.i18n; inst != nil {
		return inst, nil
	}
	if f.cfg == nil || f.cfg.I18n == nil {
		return nil, fmt.Errorf("i18n config is nil")
	}

	settings := f.cfg.I18n
	translator := i18n.NewI18n(settings.DefaultLang)
	if dir := settings.LocalesDir; dir != "" {
		if err := translator.LoadFromDir(dir); err != nil {
			return nil, err
		}
	}
	f.i18n = translator
	return translator, nil
}

// I18n returns the Factory's i18n instance, creating it on first call.
func (f *Factory) I18n() (*i18n.I18n, error) {
	return f.getI18n()
}
// getExcel returns the Excel exporter, creating it with excel.NewExcel on first
// use. It needs no configuration and cannot fail.
//
// NOTE(review): the nil check before locking reads f.excel without
// synchronization while another goroutine may be writing it under f.mu.
func (f *Factory) getExcel() *excel.Excel {
	if exporter := f.excel; exporter != nil {
		return exporter
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	if f.excel != nil {
		return f.excel
	}
	f.excel = excel.NewExcel()
	return f.excel
}

// Excel returns the Factory's Excel exporter, creating it on first call.
func (f *Factory) Excel() *excel.Excel {
	return f.getExcel()
}
// MiddlewareChain returns the default HTTP middleware chain, building it once
// and caching it. Middlewares are passed to middleware.NewChain in this order:
//
//  1. Recovery (configured with the Factory logger and i18n instance)
//  2. RequestID
//  3. Logging (configured with the Factory logger)
//  4. RateLimit — only when cfg.RateLimit is present and Enable is set
//  5. CORS — only when cfg.CORS is present
//  6. Language, then Timezone
//
// The logger and i18n instance are best-effort: if either cannot be created,
// nil is passed in its place. NOTE(review): this assumes the middleware
// package tolerates a nil Logger / I18n — confirm.
//
// NOTE(review): the nil check before locking reads f.chain without
// synchronization while another goroutine may be writing it under f.mu.
func (f *Factory) MiddlewareChain() *middleware.Chain {
	if f.chain != nil {
		return f.chain
	}

	// Resolve dependencies BEFORE taking f.mu. getLogger and getI18n lock f.mu
	// themselves whenever their component is not built yet, and sync.Mutex is
	// not reentrant, so calling them while holding the lock deadlocks. That
	// would happen on every call when no I18n config exists, because f.i18n
	// then never gets set.
	l, _ := f.getLogger()
	i18nInst, _ := f.getI18n()

	f.mu.Lock()
	defer f.mu.Unlock()
	if f.chain != nil {
		return f.chain
	}

	var mws []func(http.Handler) http.Handler
	mws = append(mws, middleware.Recovery(&middleware.RecoveryConfig{
		Logger: l,
		I18n:   i18nInst,
	}))
	mws = append(mws, middleware.RequestID())
	mws = append(mws, middleware.Logging(&middleware.LoggingConfig{Logger: l}))

	if f.cfg != nil && f.cfg.RateLimit != nil && f.cfg.RateLimit.Enable {
		rl := f.cfg.RateLimit
		// Period is interpreted as seconds.
		limiter := middleware.NewTokenBucketLimiter(
			rl.Rate,
			time.Duration(rl.Period)*time.Second,
		)

		// ByIP takes precedence over ByUserID. When neither is set, KeyFunc
		// stays nil. NOTE(review): presumably RateLimit then falls back to its
		// own default key — confirm.
		// NOTE(review): X-User-ID is read directly from the request, i.e. it
		// is client-supplied; unless an upstream auth layer overwrites it,
		// clients can rotate it to evade the limit.
		var keyFunc func(r *http.Request) string
		switch {
		case rl.ByIP:
			keyFunc = func(r *http.Request) string { return middleware.GetClientIP(r) }
		case rl.ByUserID:
			keyFunc = func(r *http.Request) string { return r.Header.Get("X-User-ID") }
		}

		mws = append(mws, middleware.RateLimit(&middleware.RateLimitConfig{
			Limiter: limiter,
			KeyFunc: keyFunc,
		}))
	}

	if f.cfg != nil && f.cfg.CORS != nil {
		mws = append(mws, middleware.CORS(&middleware.CORSConfig{
			AllowedOrigins:   f.cfg.CORS.AllowedOrigins,
			AllowedMethods:   f.cfg.CORS.AllowedMethods,
			AllowedHeaders:   f.cfg.CORS.AllowedHeaders,
			ExposedHeaders:   f.cfg.CORS.ExposedHeaders,
			AllowCredentials: f.cfg.CORS.AllowCredentials,
			MaxAge:           f.cfg.CORS.MaxAge,
		}))
	}

	mws = append(mws, middleware.Language, middleware.Timezone)
	f.chain = middleware.NewChain(mws...)
	return f.chain
}
// Migrator builds a migration.Migrator bound to the Factory's database and
// loads every "*.sql" file found in migrationsDir into it.
//
// The dialect handed to the migrator is cfg.Database.Type, falling back to
// "mysql" when it is empty. NOTE(review): within this file f.db is only set
// after getDatabase accepted a non-empty type, so the fallback looks
// unreachable — confirm no other file sets f.db.
func (f *Factory) Migrator(migrationsDir string) (*migration.Migrator, error) {
	db, err := f.getDatabase()
	if err != nil {
		return nil, err
	}

	// A successful getDatabase guarantees f.cfg is non-nil here.
	dialect := "mysql"
	if dbCfg := f.cfg.Database; dbCfg != nil && dbCfg.Type != "" {
		dialect = dbCfg.Type
	}
	m := migration.NewMigratorWithType(db, dialect)

	files, err := migration.LoadMigrationsFromFiles(migrationsDir, "*.sql")
	if err != nil {
		return nil, err
	}
	m.AddMigrations(files...)
	return m, nil
}