186 lines
5.2 KiB
Go
186 lines
5.2 KiB
Go
package factory
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"time"
|
||
|
||
"git.toowon.com/jimmy/go-common/config"
|
||
"git.toowon.com/jimmy/go-common/email"
|
||
"git.toowon.com/jimmy/go-common/logger"
|
||
"git.toowon.com/jimmy/go-common/sms"
|
||
"github.com/redis/go-redis/v9"
|
||
"gorm.io/driver/mysql"
|
||
"gorm.io/driver/postgres"
|
||
"gorm.io/driver/sqlite"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Factory 工厂类,用于从配置创建各种客户端对象
|
||
type Factory struct {
|
||
cfg *config.Config
|
||
}
|
||
|
||
// NewFactory 创建工厂实例
|
||
func NewFactory(cfg *config.Config) *Factory {
|
||
return &Factory{
|
||
cfg: cfg,
|
||
}
|
||
}
|
||
|
||
// NewFactoryFromFile 从配置文件创建工厂实例(便捷方法)
|
||
// filePath: 配置文件路径
|
||
func NewFactoryFromFile(filePath string) (*Factory, error) {
|
||
cfg, err := config.LoadFromFile(filePath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to load config: %w", err)
|
||
}
|
||
return NewFactory(cfg), nil
|
||
}
|
||
|
||
// GetEmailClient 获取邮件客户端(已初始化)
|
||
// 返回已初始化的邮件客户端对象,可直接使用
|
||
func (f *Factory) GetEmailClient() (*email.Email, error) {
|
||
if f.cfg.Email == nil {
|
||
return nil, fmt.Errorf("email config is nil")
|
||
}
|
||
return email.NewEmail(f.cfg.Email)
|
||
}
|
||
|
||
// GetSMSClient 获取短信客户端(已初始化)
|
||
// 返回已初始化的短信客户端对象,可直接使用
|
||
func (f *Factory) GetSMSClient() (*sms.SMS, error) {
|
||
if f.cfg.SMS == nil {
|
||
return nil, fmt.Errorf("SMS config is nil")
|
||
}
|
||
return sms.NewSMS(f.cfg.SMS)
|
||
}
|
||
|
||
// GetLogger 获取日志记录器(已初始化)
|
||
// 返回已初始化的日志记录器对象,可直接使用
|
||
func (f *Factory) GetLogger() (*logger.Logger, error) {
|
||
if f.cfg.Logger == nil {
|
||
// 如果没有配置,使用默认配置创建
|
||
return logger.NewLogger(nil)
|
||
}
|
||
return logger.NewLogger(f.cfg.Logger)
|
||
}
|
||
|
||
// GetDatabase 获取数据库连接对象(已初始化)
|
||
// 返回已初始化的GORM数据库对象,可直接使用
|
||
func (f *Factory) GetDatabase() (*gorm.DB, error) {
|
||
if f.cfg.Database == nil {
|
||
return nil, fmt.Errorf("database config is nil")
|
||
}
|
||
|
||
// 获取DSN
|
||
dsn, err := f.cfg.GetDatabaseDSN()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get DSN: %w", err)
|
||
}
|
||
|
||
// 根据数据库类型创建连接
|
||
var db *gorm.DB
|
||
switch f.cfg.Database.Type {
|
||
case "mysql":
|
||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||
case "postgres":
|
||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||
case "sqlite":
|
||
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||
default:
|
||
return nil, fmt.Errorf("unsupported database type: %s", f.cfg.Database.Type)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||
}
|
||
|
||
// 配置连接池
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
|
||
}
|
||
|
||
if f.cfg.Database.MaxOpenConns > 0 {
|
||
sqlDB.SetMaxOpenConns(f.cfg.Database.MaxOpenConns)
|
||
}
|
||
if f.cfg.Database.MaxIdleConns > 0 {
|
||
sqlDB.SetMaxIdleConns(f.cfg.Database.MaxIdleConns)
|
||
}
|
||
if f.cfg.Database.ConnMaxLifetime > 0 {
|
||
sqlDB.SetConnMaxLifetime(time.Duration(f.cfg.Database.ConnMaxLifetime) * time.Second)
|
||
}
|
||
|
||
return db, nil
|
||
}
|
||
|
||
// GetRedisClient 获取Redis客户端对象(已初始化)
|
||
// 返回已初始化的Redis客户端对象,可直接使用
|
||
func (f *Factory) GetRedisClient() (*redis.Client, error) {
|
||
if f.cfg.Redis == nil {
|
||
return nil, fmt.Errorf("redis config is nil")
|
||
}
|
||
|
||
// 获取Redis地址
|
||
addr := f.cfg.GetRedisAddr()
|
||
if addr == "" {
|
||
return nil, fmt.Errorf("redis address is empty")
|
||
}
|
||
|
||
// 设置默认值
|
||
redisConfig := f.cfg.Redis
|
||
if redisConfig.PoolSize == 0 {
|
||
redisConfig.PoolSize = 10 // 默认连接池大小
|
||
}
|
||
if redisConfig.MinIdleConns == 0 {
|
||
redisConfig.MinIdleConns = 5 // 默认最小空闲连接数
|
||
}
|
||
if redisConfig.DialTimeout == 0 {
|
||
redisConfig.DialTimeout = 5 // 默认连接超时5秒
|
||
}
|
||
if redisConfig.ReadTimeout == 0 {
|
||
redisConfig.ReadTimeout = 3 // 默认读取超时3秒
|
||
}
|
||
if redisConfig.WriteTimeout == 0 {
|
||
redisConfig.WriteTimeout = 3 // 默认写入超时3秒
|
||
}
|
||
|
||
// 创建Redis客户端
|
||
client := redis.NewClient(&redis.Options{
|
||
Addr: addr,
|
||
Password: redisConfig.Password,
|
||
DB: redisConfig.Database,
|
||
PoolSize: redisConfig.PoolSize,
|
||
MinIdleConns: redisConfig.MinIdleConns,
|
||
MaxRetries: redisConfig.MaxRetries,
|
||
DialTimeout: time.Duration(redisConfig.DialTimeout) * time.Second,
|
||
ReadTimeout: time.Duration(redisConfig.ReadTimeout) * time.Second,
|
||
WriteTimeout: time.Duration(redisConfig.WriteTimeout) * time.Second,
|
||
})
|
||
|
||
// 测试连接
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(redisConfig.DialTimeout)*time.Second)
|
||
defer cancel()
|
||
|
||
_, err := client.Ping(ctx).Result()
|
||
if err != nil {
|
||
client.Close() // 连接失败时关闭客户端
|
||
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||
}
|
||
|
||
return client, nil
|
||
}
|
||
|
||
// GetRedisConfig 获取Redis配置(用于创建Redis客户端)
|
||
// 返回Redis配置对象,调用方可以使用此配置创建Redis客户端
|
||
// 注意:推荐使用 GetRedisClient 方法直接获取已初始化的客户端
|
||
func (f *Factory) GetRedisConfig() *config.RedisConfig {
|
||
return f.cfg.Redis
|
||
}
|
||
|
||
// GetConfig 获取配置对象
|
||
func (f *Factory) GetConfig() *config.Config {
|
||
return f.cfg
|
||
}
|