405 lines
9.3 KiB
Go
405 lines
9.3 KiB
Go
package factory
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"net/http"
|
||
"sync"
|
||
"time"
|
||
|
||
"git.toowon.com/jimmy/go-common/config"
|
||
"git.toowon.com/jimmy/go-common/email"
|
||
"git.toowon.com/jimmy/go-common/excel"
|
||
"git.toowon.com/jimmy/go-common/i18n"
|
||
"git.toowon.com/jimmy/go-common/logger"
|
||
"git.toowon.com/jimmy/go-common/middleware"
|
||
"git.toowon.com/jimmy/go-common/migration"
|
||
"git.toowon.com/jimmy/go-common/sms"
|
||
"git.toowon.com/jimmy/go-common/storage"
|
||
"github.com/redis/go-redis/v9"
|
||
"gorm.io/driver/mysql"
|
||
"gorm.io/driver/postgres"
|
||
"gorm.io/driver/sqlite"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
var (
|
||
defaultFactory *Factory
|
||
)
|
||
|
||
// Factory 工具库入口:配置加载 + lazy getter
|
||
type Factory struct {
|
||
cfg *config.Config
|
||
storage storage.Storage
|
||
|
||
logger *logger.Logger
|
||
email *email.Email
|
||
sms *sms.SMS
|
||
db *gorm.DB
|
||
redis *redis.Client
|
||
i18n *i18n.I18n
|
||
excel *excel.Excel
|
||
chain *middleware.Chain
|
||
|
||
mu sync.Mutex
|
||
}
|
||
|
||
// Option Factory 可选项(支持重载模块实现)
|
||
type Option func(*Factory)
|
||
|
||
// WithStorage 注入自定义存储实现
|
||
func WithStorage(s storage.Storage) Option {
|
||
return func(f *Factory) {
|
||
f.storage = s
|
||
}
|
||
}
|
||
|
||
// Init 从配置文件初始化全局 Factory(启动时调用一次)
|
||
func Init(filePath string, opts ...Option) error {
|
||
cfg, err := config.LoadFromFile(filePath)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defaultFactory = New(cfg, opts...)
|
||
if l, err := defaultFactory.Logger(); err == nil {
|
||
logger.SetDefaultLogger(l)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Default 获取全局 Factory
|
||
func Default() *Factory {
|
||
return defaultFactory
|
||
}
|
||
|
||
// New 从配置创建 Factory
|
||
func New(cfg *config.Config, opts ...Option) *Factory {
|
||
f := &Factory{cfg: cfg}
|
||
for _, opt := range opts {
|
||
opt(f)
|
||
}
|
||
return f
|
||
}
|
||
|
||
// Config 获取配置
|
||
func (f *Factory) Config() *config.Config {
|
||
return f.cfg
|
||
}
|
||
|
||
// getLogger lazily creates the logger from cfg.Logger and caches it. When the
// Factory has no configuration, a nil *config.LoggerConfig is passed to
// logger.NewLogger. On error nothing is cached, so a later call retries.
func (f *Factory) getLogger() (*logger.Logger, error) {
	// Fast path: unlocked read, re-checked under f.mu below.
	// NOTE(review): racy under the Go memory model; consider sync/atomic.
	if cached := f.logger; cached != nil {
		return cached, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	if cached := f.logger; cached != nil {
		return cached, nil
	}

	var logCfg *config.LoggerConfig
	if f.cfg != nil {
		logCfg = f.cfg.Logger
	}

	created, err := logger.NewLogger(logCfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create logger: %w", err)
	}
	f.logger = created
	return created, nil
}

// Logger returns the Factory's logger, creating it on first use.
func (f *Factory) Logger() (*logger.Logger, error) {
	return f.getLogger()
}
|
||
|
||
func (f *Factory) getDatabase() (*gorm.DB, error) {
|
||
if f.db != nil {
|
||
return f.db, nil
|
||
}
|
||
f.mu.Lock()
|
||
defer f.mu.Unlock()
|
||
if f.db != nil {
|
||
return f.db, nil
|
||
}
|
||
if f.cfg == nil || f.cfg.Database == nil {
|
||
return nil, fmt.Errorf("database config is nil")
|
||
}
|
||
dsn, err := f.cfg.GetDatabaseDSN()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
var db *gorm.DB
|
||
switch f.cfg.Database.Type {
|
||
case "mysql":
|
||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||
case "postgres":
|
||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||
case "sqlite":
|
||
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||
default:
|
||
return nil, fmt.Errorf("unsupported database type: %s", f.cfg.Database.Type)
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||
}
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if f.cfg.Database.MaxOpenConns > 0 {
|
||
sqlDB.SetMaxOpenConns(f.cfg.Database.MaxOpenConns)
|
||
}
|
||
if f.cfg.Database.MaxIdleConns > 0 {
|
||
sqlDB.SetMaxIdleConns(f.cfg.Database.MaxIdleConns)
|
||
}
|
||
if f.cfg.Database.ConnMaxLifetime > 0 {
|
||
sqlDB.SetConnMaxLifetime(time.Duration(f.cfg.Database.ConnMaxLifetime) * time.Second)
|
||
}
|
||
f.db = db
|
||
return db, nil
|
||
}
|
||
|
||
// Database 获取 GORM 数据库连接
|
||
func (f *Factory) Database() (*gorm.DB, error) {
|
||
return f.getDatabase()
|
||
}
|
||
|
||
func (f *Factory) getRedis() (*redis.Client, error) {
|
||
if f.redis != nil {
|
||
return f.redis, nil
|
||
}
|
||
f.mu.Lock()
|
||
defer f.mu.Unlock()
|
||
if f.redis != nil {
|
||
return f.redis, nil
|
||
}
|
||
if f.cfg == nil || f.cfg.Redis == nil {
|
||
return nil, fmt.Errorf("redis config is nil")
|
||
}
|
||
addr := f.cfg.GetRedisAddr()
|
||
if addr == "" {
|
||
return nil, fmt.Errorf("redis address is empty")
|
||
}
|
||
rc := f.cfg.Redis
|
||
client := redis.NewClient(&redis.Options{
|
||
Addr: addr,
|
||
Password: rc.Password,
|
||
DB: rc.Database,
|
||
PoolSize: rc.PoolSize,
|
||
MinIdleConns: rc.MinIdleConns,
|
||
MaxRetries: rc.MaxRetries,
|
||
DialTimeout: time.Duration(rc.DialTimeout) * time.Second,
|
||
ReadTimeout: time.Duration(rc.ReadTimeout) * time.Second,
|
||
WriteTimeout: time.Duration(rc.WriteTimeout) * time.Second,
|
||
})
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(rc.DialTimeout)*time.Second)
|
||
defer cancel()
|
||
if _, err := client.Ping(ctx).Result(); err != nil {
|
||
_ = client.Close()
|
||
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||
}
|
||
f.redis = client
|
||
return client, nil
|
||
}
|
||
|
||
// Redis 获取 Redis 客户端
|
||
func (f *Factory) Redis() (*redis.Client, error) {
|
||
return f.getRedis()
|
||
}
|
||
|
||
// getStorage returns the injected storage (see WithStorage), or lazily builds
// one from the configuration and caches it.
//
// The backend is chosen from the first configured section, in this order:
// local storage, MinIO, OSS. On error nothing is cached.
func (f *Factory) getStorage() (storage.Storage, error) {
	// Fast path: unlocked read, re-checked under f.mu below.
	// NOTE(review): racy under the Go memory model; consider sync/atomic.
	if st := f.storage; st != nil {
		return st, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	if st := f.storage; st != nil {
		return st, nil
	}

	cfg := f.cfg
	if cfg == nil {
		return nil, fmt.Errorf("config is nil")
	}

	var kind storage.StorageType
	switch {
	case cfg.GetLocalStorage() != nil:
		kind = storage.StorageTypeLocal
	case cfg.MinIO != nil:
		kind = storage.StorageTypeMinIO
	case cfg.OSS != nil:
		kind = storage.StorageTypeOSS
	default:
		return nil, fmt.Errorf("no storage config found (LocalStorage, OSS or MinIO)")
	}

	st, err := storage.NewStorage(kind, cfg)
	if err != nil {
		return nil, err
	}
	f.storage = st
	return st, nil
}

// Storage returns the storage backend, building it on first use.
func (f *Factory) Storage() (storage.Storage, error) {
	return f.getStorage()
}
|
||
|
||
// getEmail lazily constructs the email client from the whole configuration and
// caches it. It fails when cfg or cfg.Email is nil.
func (f *Factory) getEmail() (*email.Email, error) {
	// Fast path: unlocked read, re-checked under f.mu below.
	// NOTE(review): racy under the Go memory model; consider sync/atomic.
	if existing := f.email; existing != nil {
		return existing, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	if existing := f.email; existing != nil {
		return existing, nil
	}

	if f.cfg == nil || f.cfg.Email == nil {
		return nil, fmt.Errorf("email config is nil")
	}
	client := email.NewEmail(f.cfg)
	f.email = client
	return client, nil
}

// Email returns the email client, creating it on first use.
func (f *Factory) Email() (*email.Email, error) {
	return f.getEmail()
}
|
||
|
||
// getSMS lazily constructs the SMS client from the whole configuration and
// caches it. It fails when cfg or cfg.SMS is nil.
func (f *Factory) getSMS() (*sms.SMS, error) {
	// Fast path: unlocked read, re-checked under f.mu below.
	// NOTE(review): racy under the Go memory model; consider sync/atomic.
	if existing := f.sms; existing != nil {
		return existing, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	if existing := f.sms; existing != nil {
		return existing, nil
	}

	if f.cfg == nil || f.cfg.SMS == nil {
		return nil, fmt.Errorf("sms config is nil")
	}
	client := sms.NewSMS(f.cfg)
	f.sms = client
	return client, nil
}

// SMS returns the SMS client, creating it on first use.
func (f *Factory) SMS() (*sms.SMS, error) {
	return f.getSMS()
}
|
||
|
||
// getI18n lazily builds the i18n instance from cfg.I18n and caches it.
// Translation files are loaded from LocalesDir when it is non-empty. On a load
// error nothing is cached, so a later call retries.
func (f *Factory) getI18n() (*i18n.I18n, error) {
	// Fast path: unlocked read, re-checked under f.mu below.
	// NOTE(review): racy under the Go memory model; consider sync/atomic.
	if cached := f.i18n; cached != nil {
		return cached, nil
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	if cached := f.i18n; cached != nil {
		return cached, nil
	}

	if f.cfg == nil || f.cfg.I18n == nil {
		return nil, fmt.Errorf("i18n config is nil")
	}
	settings := f.cfg.I18n

	inst := i18n.NewI18n(settings.DefaultLang)
	if dir := settings.LocalesDir; dir != "" {
		if err := inst.LoadFromDir(dir); err != nil {
			return nil, err
		}
	}

	f.i18n = inst
	return inst, nil
}

// I18n returns the internationalization instance, creating it on first use.
func (f *Factory) I18n() (*i18n.I18n, error) {
	return f.getI18n()
}
|
||
|
||
func (f *Factory) getExcel() *excel.Excel {
|
||
if f.excel != nil {
|
||
return f.excel
|
||
}
|
||
f.mu.Lock()
|
||
defer f.mu.Unlock()
|
||
if f.excel == nil {
|
||
f.excel = excel.NewExcel()
|
||
}
|
||
return f.excel
|
||
}
|
||
|
||
// Excel 获取 Excel 导出器
|
||
func (f *Factory) Excel() *excel.Excel {
|
||
return f.getExcel()
|
||
}
|
||
|
||
// MiddlewareChain returns the default HTTP middleware chain. It builds the
// chain on the first call and returns the cached chain afterwards.
//
// Middlewares are appended in this order:
//   - Recovery (with the Factory's logger and i18n instance)
//   - RequestID
//   - Logging (with the Factory's logger)
//   - RateLimit, only when cfg.RateLimit is non-nil and Enable is true
//   - CORS, only when cfg.CORS is non-nil
//   - Language, Timezone
//
// The logger and i18n instance are resolved best-effort. If either cannot be
// obtained, the nil value is passed to the middleware config.
// NOTE(review): confirm the middleware package tolerates a nil Logger/I18n.
func (f *Factory) MiddlewareChain() *middleware.Chain {
	f.mu.Lock()
	cached := f.chain
	f.mu.Unlock()
	if cached != nil {
		return cached
	}

	// Resolve dependencies before holding f.mu for the build. getLogger and
	// getI18n acquire f.mu themselves when their instance is not cached yet,
	// and sync.Mutex is not reentrant. Calling them under the lock deadlocks
	// on a fresh Factory, e.g. when New was used without Init or I18n() was
	// never called.
	l, _ := f.getLogger()      // best-effort; nil logger on error
	i18nInst, _ := f.getI18n() // best-effort; nil when i18n is not configured

	f.mu.Lock()
	defer f.mu.Unlock()
	// Another goroutine may have built the chain while the lock was released.
	if f.chain != nil {
		return f.chain
	}

	mws := make([]func(http.Handler) http.Handler, 0, 7)
	mws = append(mws,
		middleware.Recovery(&middleware.RecoveryConfig{
			Logger: l,
			I18n:   i18nInst,
		}),
		middleware.RequestID(),
		middleware.Logging(&middleware.LoggingConfig{Logger: l}),
	)

	if f.cfg != nil && f.cfg.RateLimit != nil && f.cfg.RateLimit.Enable {
		rl := f.cfg.RateLimit
		limiter := middleware.NewTokenBucketLimiter(
			rl.Rate,
			time.Duration(rl.Period)*time.Second, // Period is in seconds
		)

		// ByIP takes precedence over ByUserID; with neither set KeyFunc is nil.
		var keyFunc func(r *http.Request) string
		switch {
		case rl.ByIP:
			keyFunc = func(r *http.Request) string { return middleware.GetClientIP(r) }
		case rl.ByUserID:
			// NOTE(review): X-User-ID is read straight from the request, so a
			// client can vary it to evade the limit unless an upstream
			// component overwrites the header. Confirm this is acceptable.
			keyFunc = func(r *http.Request) string { return r.Header.Get("X-User-ID") }
		}

		mws = append(mws, middleware.RateLimit(&middleware.RateLimitConfig{
			Limiter: limiter,
			KeyFunc: keyFunc,
		}))
	}

	if f.cfg != nil && f.cfg.CORS != nil {
		c := f.cfg.CORS
		mws = append(mws, middleware.CORS(&middleware.CORSConfig{
			AllowedOrigins:   c.AllowedOrigins,
			AllowedMethods:   c.AllowedMethods,
			AllowedHeaders:   c.AllowedHeaders,
			ExposedHeaders:   c.ExposedHeaders,
			AllowCredentials: c.AllowCredentials,
			MaxAge:           c.MaxAge,
		}))
	}

	mws = append(mws, middleware.Language, middleware.Timezone)

	f.chain = middleware.NewChain(mws...)
	return f.chain
}
|
||
|
||
// Migrator 创建迁移器并加载指定目录下的 SQL 文件
|
||
func (f *Factory) Migrator(migrationsDir string) (*migration.Migrator, error) {
|
||
db, err := f.getDatabase()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
dbType := "mysql"
|
||
if f.cfg.Database != nil && f.cfg.Database.Type != "" {
|
||
dbType = f.cfg.Database.Type
|
||
}
|
||
m := migration.NewMigratorWithType(db, dbType)
|
||
migrations, err := migration.LoadMigrationsFromFiles(migrationsDir, "*.sql")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
m.AddMigrations(migrations...)
|
||
return m, nil
|
||
}
|