Files
go-common/factory/factory.go
2025-12-05 00:07:15 +08:00

767 lines
21 KiB
Go
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package factory
import (
"context"
"fmt"
"io"
"net/http"
"time"
"git.toowon.com/jimmy/go-common/config"
"git.toowon.com/jimmy/go-common/email"
"git.toowon.com/jimmy/go-common/logger"
"git.toowon.com/jimmy/go-common/middleware"
"git.toowon.com/jimmy/go-common/sms"
"git.toowon.com/jimmy/go-common/storage"
"github.com/redis/go-redis/v9"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// Factory 工厂类 - 黑盒模式设计
//
// 核心理念:
//
// 外部项目只需传递一个配置文件路径,即可直接使用所有功能,
// 无需关心内部实现细节。
//
// 推荐使用的黑盒方法:
// - GetMiddlewareChain():获取配置好的中间件链
// - LogInfo(), LogError():记录日志
// - RedisSet(), RedisGet()操作Redis
// - SendEmail(), SendSMS():发送邮件和短信
// - UploadFile(), GetFileURL():文件上传和访问
//
// 需要获取客户端对象的场景(高级功能):
// - GetDatabase()数据库操作GORM已经是很好的抽象
// - GetRedisClient()Redis高级操作Hash, List, Set, ZSet等
// - GetLogger()Logger高级功能Close等
//
// 使用示例:
//
// // 1. 创建工厂(传入配置文件路径)
// fac, _ := factory.NewFactoryFromFile("config.json")
//
// // 2. 直接使用黑盒方法(推荐)
// fac.LogInfo("用户登录成功")
// fac.RedisSet(ctx, "session:123", "data", time.Hour)
// fac.SendEmail([]string{"user@example.com"}, "主题", "内容")
// chain := fac.GetMiddlewareChain()
// chain.Append(yourAuthMiddleware) // 添加自定义中间件
//
// // 3. 获取客户端对象(仅在需要高级功能时)
// db, _ := fac.GetDatabase()
// db.Find(&users)
//
// Factory 工厂类,用于从配置创建各种客户端对象
type Factory struct {
cfg *config.Config
storage storage.Storage // 存储实例(延迟初始化)
logger *logger.Logger // 日志实例(延迟初始化)
email *email.Email // 邮件客户端(延迟初始化)
sms *sms.SMS // 短信客户端(延迟初始化)
db *gorm.DB // 数据库连接(延迟初始化)
redis *redis.Client // Redis客户端延迟初始化
}
// NewFactory 创建工厂实例
func NewFactory(cfg *config.Config) *Factory {
return &Factory{
cfg: cfg,
}
}
// NewFactoryFromFile 从配置文件创建工厂实例(便捷方法)
// filePath: 配置文件路径
func NewFactoryFromFile(filePath string) (*Factory, error) {
cfg, err := config.LoadFromFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to load config: %w", err)
}
return NewFactory(cfg), nil
}
// getEmailClient 获取邮件客户端(内部方法,延迟初始化)
func (f *Factory) getEmailClient() (*email.Email, error) {
if f.email != nil {
return f.email, nil
}
if f.cfg.Email == nil {
return nil, fmt.Errorf("email config is nil")
}
e, err := email.NewEmail(f.cfg.Email)
if err != nil {
return nil, fmt.Errorf("failed to create email client: %w", err)
}
f.email = e
return e, nil
}
// SendEmail 发送邮件(黑盒模式,推荐使用)
// 自动使用配置文件中的SMTP配置发送邮件
// to: 收件人列表
// subject: 邮件主题
// body: 邮件正文(纯文本)
// htmlBody: HTML正文可选如果设置了会优先使用
func (f *Factory) SendEmail(to []string, subject, body string, htmlBody ...string) error {
e, err := f.getEmailClient()
if err != nil {
return err
}
msg := &email.Message{
To: to,
Subject: subject,
Body: body,
}
if len(htmlBody) > 0 && htmlBody[0] != "" {
msg.HTMLBody = htmlBody[0]
}
return e.Send(msg)
}
// getSMSClient 获取短信客户端(内部方法,延迟初始化)
func (f *Factory) getSMSClient() (*sms.SMS, error) {
if f.sms != nil {
return f.sms, nil
}
if f.cfg.SMS == nil {
return nil, fmt.Errorf("SMS config is nil")
}
s, err := sms.NewSMS(f.cfg.SMS)
if err != nil {
return nil, fmt.Errorf("failed to create SMS client: %w", err)
}
f.sms = s
return s, nil
}
// SendSMS 发送短信(黑盒模式,推荐使用)
// 自动使用配置文件中的阿里云短信配置发送短信
// phoneNumbers: 手机号列表
// templateParam: 模板参数map或JSON字符串
// templateCode: 模板代码(可选,如果为空使用配置中的模板代码)
func (f *Factory) SendSMS(phoneNumbers []string, templateParam interface{}, templateCode ...string) (*sms.SendResponse, error) {
s, err := f.getSMSClient()
if err != nil {
return nil, err
}
req := &sms.SendRequest{
PhoneNumbers: phoneNumbers,
TemplateParam: templateParam,
}
if len(templateCode) > 0 && templateCode[0] != "" {
req.TemplateCode = templateCode[0]
}
return s.Send(req)
}
// getLogger 获取日志记录器(内部方法,延迟初始化)
func (f *Factory) getLogger() (*logger.Logger, error) {
if f.logger != nil {
return f.logger, nil
}
var l *logger.Logger
var err error
if f.cfg.Logger == nil {
// 如果没有配置,使用默认配置创建
l, err = logger.NewLogger(nil)
} else {
l, err = logger.NewLogger(f.cfg.Logger)
}
if err != nil {
return nil, fmt.Errorf("failed to create logger: %w", err)
}
f.logger = l
return l, nil
}
// GetLogger 获取日志记录器对象(不推荐直接使用)
//
// ⚠️ 不推荐直接使用此方法,推荐使用黑盒方法:
// - LogDebug, LogInfo, LogWarn, LogError记录简单日志
// - LogDebugf, LogInfof, LogWarnf, LogErrorf记录带字段的日志
//
// 仅在以下高级场景时使用:
// - 需要调用 Close() 方法关闭logger
// - 需要使用logger的其他高级功能
//
// 示例(不推荐):
//
// logger, _ := factory.GetLogger()
// defer logger.Close()
//
// 示例(推荐):
//
// factory.LogInfo("用户登录成功")
// factory.LogErrorf(map[string]interface{}{"user_id": 123}, "登录失败")
func (f *Factory) GetLogger() (*logger.Logger, error) {
return f.getLogger()
}
// LogDebug 记录调试日志(黑盒模式,推荐使用)
// 自动使用配置文件中的logger配置
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogDebug(message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[DEBUG] "+message+"\n", args...)
} else {
fmt.Printf("[DEBUG] %s\n", message)
}
return
}
if len(args) > 0 {
l.Debug(message, args...)
} else {
l.Debug(message)
}
}
// LogDebugf 记录调试日志(带字段,黑盒模式,推荐使用)
// 自动使用配置文件中的logger配置
// fields: 日志字段
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogDebugf(fields map[string]interface{}, message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[DEBUG] "+message+"\n", args...)
} else {
fmt.Printf("[DEBUG] %s\n", message)
}
return
}
l.Debugf(fields, message, args...)
}
// LogInfo 记录信息日志(黑盒模式,推荐使用)
// 自动使用配置文件中的logger配置
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogInfo(message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[INFO] "+message+"\n", args...)
} else {
fmt.Printf("[INFO] %s\n", message)
}
return
}
if len(args) > 0 {
l.Info(message, args...)
} else {
l.Info(message)
}
}
// LogInfof 记录信息日志(带字段,黑盒模式,推荐使用)
// 自动使用配置文件中的logger配置
// fields: 日志字段
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogInfof(fields map[string]interface{}, message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[INFO] "+message+"\n", args...)
} else {
fmt.Printf("[INFO] %s\n", message)
}
return
}
l.Infof(fields, message, args...)
}
// LogWarn 记录警告日志(黑盒模式,推荐使用)
// 自动使用配置文件中的logger配置
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogWarn(message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[WARN] "+message+"\n", args...)
} else {
fmt.Printf("[WARN] %s\n", message)
}
return
}
if len(args) > 0 {
l.Warn(message, args...)
} else {
l.Warn(message)
}
}
// LogWarnf 记录警告日志(带字段,黑盒模式,推荐使用)
// 自动使用配置文件中的logger配置
// fields: 日志字段
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogWarnf(fields map[string]interface{}, message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[WARN] "+message+"\n", args...)
} else {
fmt.Printf("[WARN] %s\n", message)
}
return
}
l.Warnf(fields, message, args...)
}
// LogError 记录错误日志(黑盒模式,推荐使用)
// 自动使用配置文件中的logger配置
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogError(message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[ERROR] "+message+"\n", args...)
} else {
fmt.Printf("[ERROR] %s\n", message)
}
return
}
if len(args) > 0 {
l.Error(message, args...)
} else {
l.Error(message)
}
}
// LogErrorf 记录错误日志(带字段,黑盒模式,推荐使用)
// 自动使用配置文件中的logger配置
// fields: 日志字段
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogErrorf(fields map[string]interface{}, message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[ERROR] "+message+"\n", args...)
} else {
fmt.Printf("[ERROR] %s\n", message)
}
return
}
l.Errorf(fields, message, args...)
}
// getDatabase 获取数据库连接对象(内部方法,延迟初始化)
func (f *Factory) getDatabase() (*gorm.DB, error) {
if f.db != nil {
return f.db, nil
}
if f.cfg.Database == nil {
return nil, fmt.Errorf("database config is nil")
}
// 获取DSN
dsn, err := f.cfg.GetDatabaseDSN()
if err != nil {
return nil, fmt.Errorf("failed to get DSN: %w", err)
}
// 根据数据库类型创建连接
var db *gorm.DB
switch f.cfg.Database.Type {
case "mysql":
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
case "postgres":
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
case "sqlite":
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
default:
return nil, fmt.Errorf("unsupported database type: %s", f.cfg.Database.Type)
}
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// 配置连接池
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
}
if f.cfg.Database.MaxOpenConns > 0 {
sqlDB.SetMaxOpenConns(f.cfg.Database.MaxOpenConns)
}
if f.cfg.Database.MaxIdleConns > 0 {
sqlDB.SetMaxIdleConns(f.cfg.Database.MaxIdleConns)
}
if f.cfg.Database.ConnMaxLifetime > 0 {
sqlDB.SetConnMaxLifetime(time.Duration(f.cfg.Database.ConnMaxLifetime) * time.Second)
}
f.db = db
return db, nil
}
// GetDatabase 获取数据库连接对象(推荐使用)
// 返回已初始化的GORM数据库对象可直接使用
//
// 数据库操作保持使用 GORM 对象,因为:
// - 数据库操作非常复杂多样(查询、插入、更新、删除、事务等)
// - GORM 已经提供了很好的抽象和 API
// - 无需在 factory 中重复封装所有数据库方法
//
// 示例:
//
// db, _ := factory.GetDatabase()
// db.Find(&users)
// db.Create(&user)
// db.Transaction(func(tx *gorm.DB) error { ... })
func (f *Factory) GetDatabase() (*gorm.DB, error) {
return f.getDatabase()
}
// getRedisClient 获取Redis客户端对象内部方法延迟初始化
func (f *Factory) getRedisClient() (*redis.Client, error) {
if f.redis != nil {
return f.redis, nil
}
if f.cfg.Redis == nil {
return nil, fmt.Errorf("redis config is nil")
}
// 获取Redis地址
addr := f.cfg.GetRedisAddr()
if addr == "" {
return nil, fmt.Errorf("redis address is empty")
}
// 设置默认值
redisConfig := f.cfg.Redis
if redisConfig.PoolSize == 0 {
redisConfig.PoolSize = 10 // 默认连接池大小
}
if redisConfig.MinIdleConns == 0 {
redisConfig.MinIdleConns = 5 // 默认最小空闲连接数
}
if redisConfig.DialTimeout == 0 {
redisConfig.DialTimeout = 5 // 默认连接超时5秒
}
if redisConfig.ReadTimeout == 0 {
redisConfig.ReadTimeout = 3 // 默认读取超时3秒
}
if redisConfig.WriteTimeout == 0 {
redisConfig.WriteTimeout = 3 // 默认写入超时3秒
}
// 创建Redis客户端
client := redis.NewClient(&redis.Options{
Addr: addr,
Password: redisConfig.Password,
DB: redisConfig.Database,
PoolSize: redisConfig.PoolSize,
MinIdleConns: redisConfig.MinIdleConns,
MaxRetries: redisConfig.MaxRetries,
DialTimeout: time.Duration(redisConfig.DialTimeout) * time.Second,
ReadTimeout: time.Duration(redisConfig.ReadTimeout) * time.Second,
WriteTimeout: time.Duration(redisConfig.WriteTimeout) * time.Second,
})
// 测试连接
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(redisConfig.DialTimeout)*time.Second)
defer cancel()
_, err := client.Ping(ctx).Result()
if err != nil {
client.Close() // 连接失败时关闭客户端
return nil, fmt.Errorf("failed to connect to redis: %w", err)
}
f.redis = client
return client, nil
}
// GetRedisClient 获取Redis客户端对象高级功能时使用
// 返回已初始化的Redis客户端对象
//
// 推荐使用黑盒方法:
// - RedisGet, RedisSet, RedisDelete, RedisExists常用操作
//
// 仅在需要使用高级功能时获取客户端:
// - Hash 操作HSet, HGet, HGetAll 等)
// - List 操作LPush, RPush, LRange 等)
// - Set 操作SAdd, SMembers 等)
// - ZSet 操作ZAdd, ZRange 等)
// - 其他高级功能
//
// 示例(常用操作,推荐):
//
// factory.RedisSet(ctx, "key", "value", time.Hour)
// value, _ := factory.RedisGet(ctx, "key")
//
// 示例(高级功能):
//
// client, _ := factory.GetRedisClient()
// client.HSet(ctx, "user:1", "name", "Alice")
// client.LPush(ctx, "queue", "task1")
func (f *Factory) GetRedisClient() (*redis.Client, error) {
return f.getRedisClient()
}
// RedisGet 获取Redis值黑盒模式推荐使用
// 自动使用配置文件中的Redis配置
// key: Redis键
func (f *Factory) RedisGet(ctx context.Context, key string) (string, error) {
client, err := f.getRedisClient()
if err != nil {
return "", err
}
result, err := client.Get(ctx, key).Result()
if err == redis.Nil {
return "", nil // key不存在返回空字符串
}
if err != nil {
return "", fmt.Errorf("failed to get redis key: %w", err)
}
return result, nil
}
// RedisSet 设置Redis值黑盒模式推荐使用
// 自动使用配置文件中的Redis配置
// key: Redis键
// value: Redis值
// expiration: 过期时间可选0表示不过期
func (f *Factory) RedisSet(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
client, err := f.getRedisClient()
if err != nil {
return err
}
var exp time.Duration
if len(expiration) > 0 {
exp = expiration[0]
}
err = client.Set(ctx, key, value, exp).Err()
if err != nil {
return fmt.Errorf("failed to set redis key: %w", err)
}
return nil
}
// RedisDelete 删除Redis键黑盒模式推荐使用
// 自动使用配置文件中的Redis配置
// keys: Redis键列表
func (f *Factory) RedisDelete(ctx context.Context, keys ...string) error {
if len(keys) == 0 {
return nil
}
client, err := f.getRedisClient()
if err != nil {
return err
}
err = client.Del(ctx, keys...).Err()
if err != nil {
return fmt.Errorf("failed to delete redis keys: %w", err)
}
return nil
}
// RedisExists 检查Redis键是否存在黑盒模式推荐使用
// 自动使用配置文件中的Redis配置
// key: Redis键
func (f *Factory) RedisExists(ctx context.Context, key string) (bool, error) {
client, err := f.getRedisClient()
if err != nil {
return false, err
}
count, err := client.Exists(ctx, key).Result()
if err != nil {
return false, fmt.Errorf("failed to check redis key existence: %w", err)
}
return count > 0, nil
}
// GetConfig 获取配置对象
func (f *Factory) GetConfig() *config.Config {
return f.cfg
}
// getStorage 获取存储实例(内部方法,延迟初始化)
func (f *Factory) getStorage() (storage.Storage, error) {
if f.storage != nil {
return f.storage, nil
}
// 根据配置自动选择存储类型
// 优先级MinIO > OSS
var storageType storage.StorageType
if f.cfg.MinIO != nil {
storageType = storage.StorageTypeMinIO
} else if f.cfg.OSS != nil {
storageType = storage.StorageTypeOSS
} else {
return nil, fmt.Errorf("no storage config found (OSS or MinIO)")
}
// 创建存储实例
s, err := storage.NewStorage(storageType, f.cfg)
if err != nil {
return nil, fmt.Errorf("failed to create storage: %w", err)
}
f.storage = s
return s, nil
}
// UploadFile 上传文件(黑盒模式,推荐使用)
// 自动根据配置选择存储类型OSS 或 MinIO无需关心内部实现
// ctx: 上下文
// objectKey: 对象键(文件路径)
// reader: 文件内容
// contentType: 文件类型(可选)
// 返回文件访问URL和错误
func (f *Factory) UploadFile(ctx context.Context, objectKey string, reader io.Reader, contentType ...string) (string, error) {
s, err := f.getStorage()
if err != nil {
return "", err
}
// 上传文件
err = s.Upload(ctx, objectKey, reader, contentType...)
if err != nil {
return "", fmt.Errorf("failed to upload file: %w", err)
}
// 获取文件URL
url, err := s.GetURL(objectKey, 0)
if err != nil {
return "", fmt.Errorf("failed to get file URL: %w", err)
}
return url, nil
}
// GetFileURL 获取文件访问URL黑盒模式推荐使用
// 自动根据配置选择存储类型返回文件的访问URL
// objectKey: 对象键(文件路径)
// expires: 过期时间0表示永久有效
func (f *Factory) GetFileURL(objectKey string, expires int64) (string, error) {
s, err := f.getStorage()
if err != nil {
return "", err
}
return s.GetURL(objectKey, expires)
}
// GetMiddlewareChain 获取配置好的中间件链(黑盒模式)
// 自动包含Recovery、Logging、RateLimit如果配置了、CORS如果配置了、Timezone
// 返回已配置好的中间件链,可以通过 Append() 方法添加自定义中间件
//
// 示例1直接使用
//
// chain := factory.GetMiddlewareChain()
// http.Handle("/api/users", chain.ThenFunc(handleUsers))
//
// 示例2添加自定义中间件
//
// chain := factory.GetMiddlewareChain()
// chain.Append(yourCustomMiddleware1, yourCustomMiddleware2)
// http.Handle("/api/users", chain.ThenFunc(handleUsers))
func (f *Factory) GetMiddlewareChain() *middleware.Chain {
var middlewares []func(http.Handler) http.Handler
// 1. Recovery 中间件必需防止panic导致服务崩溃
l, _ := f.getLogger() // 获取logger如果失败会使用默认logger
middlewares = append(middlewares, middleware.Recovery(&middleware.RecoveryConfig{
Logger: l,
}))
// 2. Logging 中间件(必需,记录所有请求)
middlewares = append(middlewares, middleware.Logging(&middleware.LoggingConfig{
Logger: l,
}))
// 3. RateLimit 中间件(如果配置了限流)
if f.cfg != nil && f.cfg.RateLimit != nil {
if f.cfg.RateLimit.Enable {
// 从配置创建限流中间件
limiter := middleware.NewTokenBucketLimiter(
f.cfg.RateLimit.Rate,
time.Duration(f.cfg.RateLimit.Period)*time.Second,
)
var keyFunc func(r *http.Request) string
if f.cfg.RateLimit.ByIP {
keyFunc = func(r *http.Request) string {
return middleware.GetClientIP(r)
}
} else if f.cfg.RateLimit.ByUserID {
keyFunc = func(r *http.Request) string {
return r.Header.Get("X-User-ID")
}
}
middlewares = append(middlewares, middleware.RateLimit(&middleware.RateLimitConfig{
Limiter: limiter,
KeyFunc: keyFunc,
}))
}
}
// 4. CORS 中间件(如果配置了)
if f.cfg != nil && f.cfg.CORS != nil {
corsConfig := &middleware.CORSConfig{
AllowedOrigins: f.cfg.CORS.AllowedOrigins,
AllowedMethods: f.cfg.CORS.AllowedMethods,
AllowedHeaders: f.cfg.CORS.AllowedHeaders,
ExposedHeaders: f.cfg.CORS.ExposedHeaders,
AllowCredentials: f.cfg.CORS.AllowCredentials,
MaxAge: f.cfg.CORS.MaxAge,
}
middlewares = append(middlewares, middleware.CORS(corsConfig))
}
// 5. Timezone 中间件(必需,处理时区)
middlewares = append(middlewares, middleware.Timezone)
return middleware.NewChain(middlewares...)
}