Files
go-common/factory/factory.go

581 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package factory
import (
"context"
"fmt"
"io"
"time"
"git.toowon.com/jimmy/go-common/config"
"git.toowon.com/jimmy/go-common/email"
"git.toowon.com/jimmy/go-common/logger"
"git.toowon.com/jimmy/go-common/sms"
"git.toowon.com/jimmy/go-common/storage"
"github.com/redis/go-redis/v9"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// Factory 工厂类,用于从配置创建各种客户端对象
type Factory struct {
cfg *config.Config
storage storage.Storage // 存储实例(延迟初始化)
logger *logger.Logger // 日志实例(延迟初始化)
email *email.Email // 邮件客户端(延迟初始化)
sms *sms.SMS // 短信客户端(延迟初始化)
db *gorm.DB // 数据库连接(延迟初始化)
redis *redis.Client // Redis客户端延迟初始化
}
// NewFactory 创建工厂实例
func NewFactory(cfg *config.Config) *Factory {
return &Factory{
cfg: cfg,
}
}
// NewFactoryFromFile 从配置文件创建工厂实例(便捷方法)
// filePath: 配置文件路径
func NewFactoryFromFile(filePath string) (*Factory, error) {
cfg, err := config.LoadFromFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to load config: %w", err)
}
return NewFactory(cfg), nil
}
// getEmailClient 获取邮件客户端(内部方法,延迟初始化)
func (f *Factory) getEmailClient() (*email.Email, error) {
if f.email != nil {
return f.email, nil
}
if f.cfg.Email == nil {
return nil, fmt.Errorf("email config is nil")
}
e, err := email.NewEmail(f.cfg.Email)
if err != nil {
return nil, fmt.Errorf("failed to create email client: %w", err)
}
f.email = e
return e, nil
}
// SendEmail 发送邮件(黑盒模式)
// to: 收件人列表
// subject: 邮件主题
// body: 邮件正文(纯文本)
// htmlBody: HTML正文可选如果设置了会优先使用
func (f *Factory) SendEmail(to []string, subject, body string, htmlBody ...string) error {
e, err := f.getEmailClient()
if err != nil {
return err
}
msg := &email.Message{
To: to,
Subject: subject,
Body: body,
}
if len(htmlBody) > 0 && htmlBody[0] != "" {
msg.HTMLBody = htmlBody[0]
}
return e.Send(msg)
}
// getSMSClient 获取短信客户端(内部方法,延迟初始化)
func (f *Factory) getSMSClient() (*sms.SMS, error) {
if f.sms != nil {
return f.sms, nil
}
if f.cfg.SMS == nil {
return nil, fmt.Errorf("SMS config is nil")
}
s, err := sms.NewSMS(f.cfg.SMS)
if err != nil {
return nil, fmt.Errorf("failed to create SMS client: %w", err)
}
f.sms = s
return s, nil
}
// SendSMS 发送短信(黑盒模式)
// phoneNumbers: 手机号列表
// templateParam: 模板参数map或JSON字符串
// templateCode: 模板代码(可选,如果为空使用配置中的模板代码)
func (f *Factory) SendSMS(phoneNumbers []string, templateParam interface{}, templateCode ...string) (*sms.SendResponse, error) {
s, err := f.getSMSClient()
if err != nil {
return nil, err
}
req := &sms.SendRequest{
PhoneNumbers: phoneNumbers,
TemplateParam: templateParam,
}
if len(templateCode) > 0 && templateCode[0] != "" {
req.TemplateCode = templateCode[0]
}
return s.Send(req)
}
// getLogger 获取日志记录器(内部方法,延迟初始化)
func (f *Factory) getLogger() (*logger.Logger, error) {
if f.logger != nil {
return f.logger, nil
}
var l *logger.Logger
var err error
if f.cfg.Logger == nil {
// 如果没有配置,使用默认配置创建
l, err = logger.NewLogger(nil)
} else {
l, err = logger.NewLogger(f.cfg.Logger)
}
if err != nil {
return nil, fmt.Errorf("failed to create logger: %w", err)
}
f.logger = l
return l, nil
}
// LogDebug 记录调试日志
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogDebug(message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[DEBUG] "+message+"\n", args...)
} else {
fmt.Printf("[DEBUG] %s\n", message)
}
return
}
if len(args) > 0 {
l.Debug(message, args...)
} else {
l.Debug(message)
}
}
// LogDebugf 记录调试日志(带字段)
// fields: 日志字段
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogDebugf(fields map[string]interface{}, message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[DEBUG] "+message+"\n", args...)
} else {
fmt.Printf("[DEBUG] %s\n", message)
}
return
}
l.Debugf(fields, message, args...)
}
// LogInfo 记录信息日志
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogInfo(message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[INFO] "+message+"\n", args...)
} else {
fmt.Printf("[INFO] %s\n", message)
}
return
}
if len(args) > 0 {
l.Info(message, args...)
} else {
l.Info(message)
}
}
// LogInfof 记录信息日志(带字段)
// fields: 日志字段
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogInfof(fields map[string]interface{}, message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[INFO] "+message+"\n", args...)
} else {
fmt.Printf("[INFO] %s\n", message)
}
return
}
l.Infof(fields, message, args...)
}
// LogWarn 记录警告日志
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogWarn(message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[WARN] "+message+"\n", args...)
} else {
fmt.Printf("[WARN] %s\n", message)
}
return
}
if len(args) > 0 {
l.Warn(message, args...)
} else {
l.Warn(message)
}
}
// LogWarnf 记录警告日志(带字段)
// fields: 日志字段
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogWarnf(fields map[string]interface{}, message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[WARN] "+message+"\n", args...)
} else {
fmt.Printf("[WARN] %s\n", message)
}
return
}
l.Warnf(fields, message, args...)
}
// LogError 记录错误日志
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogError(message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[ERROR] "+message+"\n", args...)
} else {
fmt.Printf("[ERROR] %s\n", message)
}
return
}
if len(args) > 0 {
l.Error(message, args...)
} else {
l.Error(message)
}
}
// LogErrorf 记录错误日志(带字段)
// fields: 日志字段
// message: 日志消息
// args: 格式化参数(可选)
func (f *Factory) LogErrorf(fields map[string]interface{}, message string, args ...interface{}) {
l, err := f.getLogger()
if err != nil {
// 如果日志初始化失败,使用标准输出
if len(args) > 0 {
fmt.Printf("[ERROR] "+message+"\n", args...)
} else {
fmt.Printf("[ERROR] %s\n", message)
}
return
}
l.Errorf(fields, message, args...)
}
// getDatabase 获取数据库连接对象(内部方法,延迟初始化)
func (f *Factory) getDatabase() (*gorm.DB, error) {
if f.db != nil {
return f.db, nil
}
if f.cfg.Database == nil {
return nil, fmt.Errorf("database config is nil")
}
// 获取DSN
dsn, err := f.cfg.GetDatabaseDSN()
if err != nil {
return nil, fmt.Errorf("failed to get DSN: %w", err)
}
// 根据数据库类型创建连接
var db *gorm.DB
switch f.cfg.Database.Type {
case "mysql":
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
case "postgres":
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
case "sqlite":
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
default:
return nil, fmt.Errorf("unsupported database type: %s", f.cfg.Database.Type)
}
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// 配置连接池
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
}
if f.cfg.Database.MaxOpenConns > 0 {
sqlDB.SetMaxOpenConns(f.cfg.Database.MaxOpenConns)
}
if f.cfg.Database.MaxIdleConns > 0 {
sqlDB.SetMaxIdleConns(f.cfg.Database.MaxIdleConns)
}
if f.cfg.Database.ConnMaxLifetime > 0 {
sqlDB.SetConnMaxLifetime(time.Duration(f.cfg.Database.ConnMaxLifetime) * time.Second)
}
f.db = db
return db, nil
}
// GetDatabase 获取数据库连接对象(已初始化)
// 返回已初始化的GORM数据库对象可直接使用
// 注意数据库保持返回GORM对象因为GORM已经提供了很好的抽象
func (f *Factory) GetDatabase() (*gorm.DB, error) {
return f.getDatabase()
}
// getRedisClient 获取Redis客户端对象内部方法延迟初始化
func (f *Factory) getRedisClient() (*redis.Client, error) {
if f.redis != nil {
return f.redis, nil
}
if f.cfg.Redis == nil {
return nil, fmt.Errorf("redis config is nil")
}
// 获取Redis地址
addr := f.cfg.GetRedisAddr()
if addr == "" {
return nil, fmt.Errorf("redis address is empty")
}
// 设置默认值
redisConfig := f.cfg.Redis
if redisConfig.PoolSize == 0 {
redisConfig.PoolSize = 10 // 默认连接池大小
}
if redisConfig.MinIdleConns == 0 {
redisConfig.MinIdleConns = 5 // 默认最小空闲连接数
}
if redisConfig.DialTimeout == 0 {
redisConfig.DialTimeout = 5 // 默认连接超时5秒
}
if redisConfig.ReadTimeout == 0 {
redisConfig.ReadTimeout = 3 // 默认读取超时3秒
}
if redisConfig.WriteTimeout == 0 {
redisConfig.WriteTimeout = 3 // 默认写入超时3秒
}
// 创建Redis客户端
client := redis.NewClient(&redis.Options{
Addr: addr,
Password: redisConfig.Password,
DB: redisConfig.Database,
PoolSize: redisConfig.PoolSize,
MinIdleConns: redisConfig.MinIdleConns,
MaxRetries: redisConfig.MaxRetries,
DialTimeout: time.Duration(redisConfig.DialTimeout) * time.Second,
ReadTimeout: time.Duration(redisConfig.ReadTimeout) * time.Second,
WriteTimeout: time.Duration(redisConfig.WriteTimeout) * time.Second,
})
// 测试连接
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(redisConfig.DialTimeout)*time.Second)
defer cancel()
_, err := client.Ping(ctx).Result()
if err != nil {
client.Close() // 连接失败时关闭客户端
return nil, fmt.Errorf("failed to connect to redis: %w", err)
}
f.redis = client
return client, nil
}
// RedisGet 获取Redis值黑盒模式
// key: Redis键
func (f *Factory) RedisGet(ctx context.Context, key string) (string, error) {
client, err := f.getRedisClient()
if err != nil {
return "", err
}
result, err := client.Get(ctx, key).Result()
if err == redis.Nil {
return "", nil // key不存在返回空字符串
}
if err != nil {
return "", fmt.Errorf("failed to get redis key: %w", err)
}
return result, nil
}
// RedisSet 设置Redis值黑盒模式
// key: Redis键
// value: Redis值
// expiration: 过期时间可选0表示不过期
func (f *Factory) RedisSet(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
client, err := f.getRedisClient()
if err != nil {
return err
}
var exp time.Duration
if len(expiration) > 0 {
exp = expiration[0]
}
err = client.Set(ctx, key, value, exp).Err()
if err != nil {
return fmt.Errorf("failed to set redis key: %w", err)
}
return nil
}
// RedisDelete 删除Redis键黑盒模式
// keys: Redis键列表
func (f *Factory) RedisDelete(ctx context.Context, keys ...string) error {
if len(keys) == 0 {
return nil
}
client, err := f.getRedisClient()
if err != nil {
return err
}
err = client.Del(ctx, keys...).Err()
if err != nil {
return fmt.Errorf("failed to delete redis keys: %w", err)
}
return nil
}
// RedisExists 检查Redis键是否存在黑盒模式
// key: Redis键
func (f *Factory) RedisExists(ctx context.Context, key string) (bool, error) {
client, err := f.getRedisClient()
if err != nil {
return false, err
}
count, err := client.Exists(ctx, key).Result()
if err != nil {
return false, fmt.Errorf("failed to check redis key existence: %w", err)
}
return count > 0, nil
}
// GetConfig 获取配置对象
func (f *Factory) GetConfig() *config.Config {
return f.cfg
}
// getStorage 获取存储实例(内部方法,延迟初始化)
func (f *Factory) getStorage() (storage.Storage, error) {
if f.storage != nil {
return f.storage, nil
}
// 根据配置自动选择存储类型
// 优先级MinIO > OSS
var storageType storage.StorageType
if f.cfg.MinIO != nil {
storageType = storage.StorageTypeMinIO
} else if f.cfg.OSS != nil {
storageType = storage.StorageTypeOSS
} else {
return nil, fmt.Errorf("no storage config found (OSS or MinIO)")
}
// 创建存储实例
s, err := storage.NewStorage(storageType, f.cfg)
if err != nil {
return nil, fmt.Errorf("failed to create storage: %w", err)
}
f.storage = s
return s, nil
}
// UploadFile 上传文件
// ctx: 上下文
// objectKey: 对象键(文件路径)
// reader: 文件内容
// contentType: 文件类型(可选)
// 返回文件访问URL和错误
func (f *Factory) UploadFile(ctx context.Context, objectKey string, reader io.Reader, contentType ...string) (string, error) {
s, err := f.getStorage()
if err != nil {
return "", err
}
// 上传文件
err = s.Upload(ctx, objectKey, reader, contentType...)
if err != nil {
return "", fmt.Errorf("failed to upload file: %w", err)
}
// 获取文件URL
url, err := s.GetURL(objectKey, 0)
if err != nil {
return "", fmt.Errorf("failed to get file URL: %w", err)
}
return url, nil
}
// GetFileURL 获取文件访问URLShow方法
// objectKey: 对象键
// expires: 过期时间0表示永久有效
func (f *Factory) GetFileURL(objectKey string, expires int64) (string, error) {
s, err := f.getStorage()
if err != nil {
return "", err
}
return s.GetURL(objectKey, expires)
}