597 lines
15 KiB
Go
597 lines
15 KiB
Go
package factory
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"time"
|
||
|
||
"git.toowon.com/jimmy/go-common/config"
|
||
"git.toowon.com/jimmy/go-common/email"
|
||
"git.toowon.com/jimmy/go-common/logger"
|
||
"git.toowon.com/jimmy/go-common/sms"
|
||
"git.toowon.com/jimmy/go-common/storage"
|
||
"github.com/redis/go-redis/v9"
|
||
"gorm.io/driver/mysql"
|
||
"gorm.io/driver/postgres"
|
||
"gorm.io/driver/sqlite"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Factory 工厂类,用于从配置创建各种客户端对象
|
||
type Factory struct {
|
||
cfg *config.Config
|
||
storage storage.Storage // 存储实例(延迟初始化)
|
||
logger *logger.Logger // 日志实例(延迟初始化)
|
||
email *email.Email // 邮件客户端(延迟初始化)
|
||
sms *sms.SMS // 短信客户端(延迟初始化)
|
||
db *gorm.DB // 数据库连接(延迟初始化)
|
||
redis *redis.Client // Redis客户端(延迟初始化)
|
||
}
|
||
|
||
// NewFactory 创建工厂实例
|
||
func NewFactory(cfg *config.Config) *Factory {
|
||
return &Factory{
|
||
cfg: cfg,
|
||
}
|
||
}
|
||
|
||
// NewFactoryFromFile 从配置文件创建工厂实例(便捷方法)
|
||
// filePath: 配置文件路径
|
||
func NewFactoryFromFile(filePath string) (*Factory, error) {
|
||
cfg, err := config.LoadFromFile(filePath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to load config: %w", err)
|
||
}
|
||
return NewFactory(cfg), nil
|
||
}
|
||
|
||
// getEmailClient 获取邮件客户端(内部方法,延迟初始化)
|
||
func (f *Factory) getEmailClient() (*email.Email, error) {
|
||
if f.email != nil {
|
||
return f.email, nil
|
||
}
|
||
|
||
if f.cfg.Email == nil {
|
||
return nil, fmt.Errorf("email config is nil")
|
||
}
|
||
|
||
e, err := email.NewEmail(f.cfg.Email)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create email client: %w", err)
|
||
}
|
||
|
||
f.email = e
|
||
return e, nil
|
||
}
|
||
|
||
// SendEmail 发送邮件(黑盒模式)
|
||
// to: 收件人列表
|
||
// subject: 邮件主题
|
||
// body: 邮件正文(纯文本)
|
||
// htmlBody: HTML正文(可选,如果设置了会优先使用)
|
||
func (f *Factory) SendEmail(to []string, subject, body string, htmlBody ...string) error {
|
||
e, err := f.getEmailClient()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
msg := &email.Message{
|
||
To: to,
|
||
Subject: subject,
|
||
Body: body,
|
||
}
|
||
|
||
if len(htmlBody) > 0 && htmlBody[0] != "" {
|
||
msg.HTMLBody = htmlBody[0]
|
||
}
|
||
|
||
return e.Send(msg)
|
||
}
|
||
|
||
// getSMSClient 获取短信客户端(内部方法,延迟初始化)
|
||
func (f *Factory) getSMSClient() (*sms.SMS, error) {
|
||
if f.sms != nil {
|
||
return f.sms, nil
|
||
}
|
||
|
||
if f.cfg.SMS == nil {
|
||
return nil, fmt.Errorf("SMS config is nil")
|
||
}
|
||
|
||
s, err := sms.NewSMS(f.cfg.SMS)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create SMS client: %w", err)
|
||
}
|
||
|
||
f.sms = s
|
||
return s, nil
|
||
}
|
||
|
||
// SendSMS 发送短信(黑盒模式)
|
||
// phoneNumbers: 手机号列表
|
||
// templateParam: 模板参数(map或JSON字符串)
|
||
// templateCode: 模板代码(可选,如果为空使用配置中的模板代码)
|
||
func (f *Factory) SendSMS(phoneNumbers []string, templateParam interface{}, templateCode ...string) (*sms.SendResponse, error) {
|
||
s, err := f.getSMSClient()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
req := &sms.SendRequest{
|
||
PhoneNumbers: phoneNumbers,
|
||
TemplateParam: templateParam,
|
||
}
|
||
|
||
if len(templateCode) > 0 && templateCode[0] != "" {
|
||
req.TemplateCode = templateCode[0]
|
||
}
|
||
|
||
return s.Send(req)
|
||
}
|
||
|
||
// getLogger 获取日志记录器(内部方法,延迟初始化)
|
||
func (f *Factory) getLogger() (*logger.Logger, error) {
|
||
if f.logger != nil {
|
||
return f.logger, nil
|
||
}
|
||
|
||
var l *logger.Logger
|
||
var err error
|
||
if f.cfg.Logger == nil {
|
||
// 如果没有配置,使用默认配置创建
|
||
l, err = logger.NewLogger(nil)
|
||
} else {
|
||
l, err = logger.NewLogger(f.cfg.Logger)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create logger: %w", err)
|
||
}
|
||
|
||
f.logger = l
|
||
return l, nil
|
||
}
|
||
|
||
// GetLogger 获取日志记录器对象(已初始化)
|
||
// 返回已初始化的日志记录器对象,可直接使用
|
||
// 注意:推荐使用 LogDebug、LogInfo、LogWarn、LogError 等方法直接记录日志
|
||
// 如果需要使用logger的高级功能(如Close方法),可以使用此方法获取logger对象
|
||
func (f *Factory) GetLogger() (*logger.Logger, error) {
|
||
return f.getLogger()
|
||
}
|
||
|
||
// LogDebug 记录调试日志
|
||
// message: 日志消息
|
||
// args: 格式化参数(可选)
|
||
func (f *Factory) LogDebug(message string, args ...interface{}) {
|
||
l, err := f.getLogger()
|
||
if err != nil {
|
||
// 如果日志初始化失败,使用标准输出
|
||
if len(args) > 0 {
|
||
fmt.Printf("[DEBUG] "+message+"\n", args...)
|
||
} else {
|
||
fmt.Printf("[DEBUG] %s\n", message)
|
||
}
|
||
return
|
||
}
|
||
if len(args) > 0 {
|
||
l.Debug(message, args...)
|
||
} else {
|
||
l.Debug(message)
|
||
}
|
||
}
|
||
|
||
// LogDebugf 记录调试日志(带字段)
|
||
// fields: 日志字段
|
||
// message: 日志消息
|
||
// args: 格式化参数(可选)
|
||
func (f *Factory) LogDebugf(fields map[string]interface{}, message string, args ...interface{}) {
|
||
l, err := f.getLogger()
|
||
if err != nil {
|
||
// 如果日志初始化失败,使用标准输出
|
||
if len(args) > 0 {
|
||
fmt.Printf("[DEBUG] "+message+"\n", args...)
|
||
} else {
|
||
fmt.Printf("[DEBUG] %s\n", message)
|
||
}
|
||
return
|
||
}
|
||
l.Debugf(fields, message, args...)
|
||
}
|
||
|
||
// LogInfo 记录信息日志
|
||
// message: 日志消息
|
||
// args: 格式化参数(可选)
|
||
func (f *Factory) LogInfo(message string, args ...interface{}) {
|
||
l, err := f.getLogger()
|
||
if err != nil {
|
||
// 如果日志初始化失败,使用标准输出
|
||
if len(args) > 0 {
|
||
fmt.Printf("[INFO] "+message+"\n", args...)
|
||
} else {
|
||
fmt.Printf("[INFO] %s\n", message)
|
||
}
|
||
return
|
||
}
|
||
if len(args) > 0 {
|
||
l.Info(message, args...)
|
||
} else {
|
||
l.Info(message)
|
||
}
|
||
}
|
||
|
||
// LogInfof 记录信息日志(带字段)
|
||
// fields: 日志字段
|
||
// message: 日志消息
|
||
// args: 格式化参数(可选)
|
||
func (f *Factory) LogInfof(fields map[string]interface{}, message string, args ...interface{}) {
|
||
l, err := f.getLogger()
|
||
if err != nil {
|
||
// 如果日志初始化失败,使用标准输出
|
||
if len(args) > 0 {
|
||
fmt.Printf("[INFO] "+message+"\n", args...)
|
||
} else {
|
||
fmt.Printf("[INFO] %s\n", message)
|
||
}
|
||
return
|
||
}
|
||
l.Infof(fields, message, args...)
|
||
}
|
||
|
||
// LogWarn 记录警告日志
|
||
// message: 日志消息
|
||
// args: 格式化参数(可选)
|
||
func (f *Factory) LogWarn(message string, args ...interface{}) {
|
||
l, err := f.getLogger()
|
||
if err != nil {
|
||
// 如果日志初始化失败,使用标准输出
|
||
if len(args) > 0 {
|
||
fmt.Printf("[WARN] "+message+"\n", args...)
|
||
} else {
|
||
fmt.Printf("[WARN] %s\n", message)
|
||
}
|
||
return
|
||
}
|
||
if len(args) > 0 {
|
||
l.Warn(message, args...)
|
||
} else {
|
||
l.Warn(message)
|
||
}
|
||
}
|
||
|
||
// LogWarnf 记录警告日志(带字段)
|
||
// fields: 日志字段
|
||
// message: 日志消息
|
||
// args: 格式化参数(可选)
|
||
func (f *Factory) LogWarnf(fields map[string]interface{}, message string, args ...interface{}) {
|
||
l, err := f.getLogger()
|
||
if err != nil {
|
||
// 如果日志初始化失败,使用标准输出
|
||
if len(args) > 0 {
|
||
fmt.Printf("[WARN] "+message+"\n", args...)
|
||
} else {
|
||
fmt.Printf("[WARN] %s\n", message)
|
||
}
|
||
return
|
||
}
|
||
l.Warnf(fields, message, args...)
|
||
}
|
||
|
||
// LogError 记录错误日志
|
||
// message: 日志消息
|
||
// args: 格式化参数(可选)
|
||
func (f *Factory) LogError(message string, args ...interface{}) {
|
||
l, err := f.getLogger()
|
||
if err != nil {
|
||
// 如果日志初始化失败,使用标准输出
|
||
if len(args) > 0 {
|
||
fmt.Printf("[ERROR] "+message+"\n", args...)
|
||
} else {
|
||
fmt.Printf("[ERROR] %s\n", message)
|
||
}
|
||
return
|
||
}
|
||
if len(args) > 0 {
|
||
l.Error(message, args...)
|
||
} else {
|
||
l.Error(message)
|
||
}
|
||
}
|
||
|
||
// LogErrorf 记录错误日志(带字段)
|
||
// fields: 日志字段
|
||
// message: 日志消息
|
||
// args: 格式化参数(可选)
|
||
func (f *Factory) LogErrorf(fields map[string]interface{}, message string, args ...interface{}) {
|
||
l, err := f.getLogger()
|
||
if err != nil {
|
||
// 如果日志初始化失败,使用标准输出
|
||
if len(args) > 0 {
|
||
fmt.Printf("[ERROR] "+message+"\n", args...)
|
||
} else {
|
||
fmt.Printf("[ERROR] %s\n", message)
|
||
}
|
||
return
|
||
}
|
||
l.Errorf(fields, message, args...)
|
||
}
|
||
|
||
// getDatabase 获取数据库连接对象(内部方法,延迟初始化)
|
||
func (f *Factory) getDatabase() (*gorm.DB, error) {
|
||
if f.db != nil {
|
||
return f.db, nil
|
||
}
|
||
|
||
if f.cfg.Database == nil {
|
||
return nil, fmt.Errorf("database config is nil")
|
||
}
|
||
|
||
// 获取DSN
|
||
dsn, err := f.cfg.GetDatabaseDSN()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get DSN: %w", err)
|
||
}
|
||
|
||
// 根据数据库类型创建连接
|
||
var db *gorm.DB
|
||
switch f.cfg.Database.Type {
|
||
case "mysql":
|
||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||
case "postgres":
|
||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||
case "sqlite":
|
||
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||
default:
|
||
return nil, fmt.Errorf("unsupported database type: %s", f.cfg.Database.Type)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||
}
|
||
|
||
// 配置连接池
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
|
||
}
|
||
|
||
if f.cfg.Database.MaxOpenConns > 0 {
|
||
sqlDB.SetMaxOpenConns(f.cfg.Database.MaxOpenConns)
|
||
}
|
||
if f.cfg.Database.MaxIdleConns > 0 {
|
||
sqlDB.SetMaxIdleConns(f.cfg.Database.MaxIdleConns)
|
||
}
|
||
if f.cfg.Database.ConnMaxLifetime > 0 {
|
||
sqlDB.SetConnMaxLifetime(time.Duration(f.cfg.Database.ConnMaxLifetime) * time.Second)
|
||
}
|
||
|
||
f.db = db
|
||
return db, nil
|
||
}
|
||
|
||
// GetDatabase 获取数据库连接对象(已初始化)
|
||
// 返回已初始化的GORM数据库对象,可直接使用
|
||
// 注意:数据库保持返回GORM对象,因为GORM已经提供了很好的抽象
|
||
func (f *Factory) GetDatabase() (*gorm.DB, error) {
|
||
return f.getDatabase()
|
||
}
|
||
|
||
// getRedisClient 获取Redis客户端对象(内部方法,延迟初始化)
|
||
func (f *Factory) getRedisClient() (*redis.Client, error) {
|
||
if f.redis != nil {
|
||
return f.redis, nil
|
||
}
|
||
|
||
if f.cfg.Redis == nil {
|
||
return nil, fmt.Errorf("redis config is nil")
|
||
}
|
||
|
||
// 获取Redis地址
|
||
addr := f.cfg.GetRedisAddr()
|
||
if addr == "" {
|
||
return nil, fmt.Errorf("redis address is empty")
|
||
}
|
||
|
||
// 设置默认值
|
||
redisConfig := f.cfg.Redis
|
||
if redisConfig.PoolSize == 0 {
|
||
redisConfig.PoolSize = 10 // 默认连接池大小
|
||
}
|
||
if redisConfig.MinIdleConns == 0 {
|
||
redisConfig.MinIdleConns = 5 // 默认最小空闲连接数
|
||
}
|
||
if redisConfig.DialTimeout == 0 {
|
||
redisConfig.DialTimeout = 5 // 默认连接超时5秒
|
||
}
|
||
if redisConfig.ReadTimeout == 0 {
|
||
redisConfig.ReadTimeout = 3 // 默认读取超时3秒
|
||
}
|
||
if redisConfig.WriteTimeout == 0 {
|
||
redisConfig.WriteTimeout = 3 // 默认写入超时3秒
|
||
}
|
||
|
||
// 创建Redis客户端
|
||
client := redis.NewClient(&redis.Options{
|
||
Addr: addr,
|
||
Password: redisConfig.Password,
|
||
DB: redisConfig.Database,
|
||
PoolSize: redisConfig.PoolSize,
|
||
MinIdleConns: redisConfig.MinIdleConns,
|
||
MaxRetries: redisConfig.MaxRetries,
|
||
DialTimeout: time.Duration(redisConfig.DialTimeout) * time.Second,
|
||
ReadTimeout: time.Duration(redisConfig.ReadTimeout) * time.Second,
|
||
WriteTimeout: time.Duration(redisConfig.WriteTimeout) * time.Second,
|
||
})
|
||
|
||
// 测试连接
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(redisConfig.DialTimeout)*time.Second)
|
||
defer cancel()
|
||
|
||
_, err := client.Ping(ctx).Result()
|
||
if err != nil {
|
||
client.Close() // 连接失败时关闭客户端
|
||
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||
}
|
||
|
||
f.redis = client
|
||
return client, nil
|
||
}
|
||
|
||
// GetRedisClient 获取Redis客户端对象(已初始化)
|
||
// 返回已初始化的Redis客户端对象,可直接使用
|
||
// 注意:推荐使用 RedisGet、RedisSet、RedisDelete 等方法直接操作Redis
|
||
// 如果需要使用Redis的高级功能(如Hash、List、Set等),可以使用此方法获取客户端对象
|
||
func (f *Factory) GetRedisClient() (*redis.Client, error) {
|
||
return f.getRedisClient()
|
||
}
|
||
|
||
// RedisGet 获取Redis值(黑盒模式)
|
||
// key: Redis键
|
||
func (f *Factory) RedisGet(ctx context.Context, key string) (string, error) {
|
||
client, err := f.getRedisClient()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
result, err := client.Get(ctx, key).Result()
|
||
if err == redis.Nil {
|
||
return "", nil // key不存在,返回空字符串
|
||
}
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to get redis key: %w", err)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// RedisSet 设置Redis值(黑盒模式)
|
||
// key: Redis键
|
||
// value: Redis值
|
||
// expiration: 过期时间(可选,0表示不过期)
|
||
func (f *Factory) RedisSet(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
|
||
client, err := f.getRedisClient()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
var exp time.Duration
|
||
if len(expiration) > 0 {
|
||
exp = expiration[0]
|
||
}
|
||
|
||
err = client.Set(ctx, key, value, exp).Err()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to set redis key: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// RedisDelete 删除Redis键(黑盒模式)
|
||
// keys: Redis键列表
|
||
func (f *Factory) RedisDelete(ctx context.Context, keys ...string) error {
|
||
if len(keys) == 0 {
|
||
return nil
|
||
}
|
||
|
||
client, err := f.getRedisClient()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
err = client.Del(ctx, keys...).Err()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to delete redis keys: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// RedisExists 检查Redis键是否存在(黑盒模式)
|
||
// key: Redis键
|
||
func (f *Factory) RedisExists(ctx context.Context, key string) (bool, error) {
|
||
client, err := f.getRedisClient()
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
count, err := client.Exists(ctx, key).Result()
|
||
if err != nil {
|
||
return false, fmt.Errorf("failed to check redis key existence: %w", err)
|
||
}
|
||
|
||
return count > 0, nil
|
||
}
|
||
|
||
// GetConfig 获取配置对象
|
||
func (f *Factory) GetConfig() *config.Config {
|
||
return f.cfg
|
||
}
|
||
|
||
// getStorage 获取存储实例(内部方法,延迟初始化)
|
||
func (f *Factory) getStorage() (storage.Storage, error) {
|
||
if f.storage != nil {
|
||
return f.storage, nil
|
||
}
|
||
|
||
// 根据配置自动选择存储类型
|
||
// 优先级:MinIO > OSS
|
||
var storageType storage.StorageType
|
||
if f.cfg.MinIO != nil {
|
||
storageType = storage.StorageTypeMinIO
|
||
} else if f.cfg.OSS != nil {
|
||
storageType = storage.StorageTypeOSS
|
||
} else {
|
||
return nil, fmt.Errorf("no storage config found (OSS or MinIO)")
|
||
}
|
||
|
||
// 创建存储实例
|
||
s, err := storage.NewStorage(storageType, f.cfg)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create storage: %w", err)
|
||
}
|
||
|
||
f.storage = s
|
||
return s, nil
|
||
}
|
||
|
||
// UploadFile 上传文件
|
||
// ctx: 上下文
|
||
// objectKey: 对象键(文件路径)
|
||
// reader: 文件内容
|
||
// contentType: 文件类型(可选)
|
||
// 返回文件访问URL和错误
|
||
func (f *Factory) UploadFile(ctx context.Context, objectKey string, reader io.Reader, contentType ...string) (string, error) {
|
||
s, err := f.getStorage()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 上传文件
|
||
err = s.Upload(ctx, objectKey, reader, contentType...)
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to upload file: %w", err)
|
||
}
|
||
|
||
// 获取文件URL
|
||
url, err := s.GetURL(objectKey, 0)
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to get file URL: %w", err)
|
||
}
|
||
|
||
return url, nil
|
||
}
|
||
|
||
// GetFileURL 获取文件访问URL(Show方法)
|
||
// objectKey: 对象键
|
||
// expires: 过期时间(秒),0表示永久有效
|
||
func (f *Factory) GetFileURL(objectKey string, expires int64) (string, error) {
|
||
s, err := f.getStorage()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
return s.GetURL(objectKey, expires)
|
||
}
|