Files
go-common/factory/factory.go

186 lines
5.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package factory
import (
"context"
"fmt"
"time"
"git.toowon.com/jimmy/go-common/config"
"git.toowon.com/jimmy/go-common/email"
"git.toowon.com/jimmy/go-common/logger"
"git.toowon.com/jimmy/go-common/sms"
"github.com/redis/go-redis/v9"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// Factory 工厂类,用于从配置创建各种客户端对象
type Factory struct {
cfg *config.Config
}
// NewFactory 创建工厂实例
func NewFactory(cfg *config.Config) *Factory {
return &Factory{
cfg: cfg,
}
}
// NewFactoryFromFile 从配置文件创建工厂实例(便捷方法)
// filePath: 配置文件路径
func NewFactoryFromFile(filePath string) (*Factory, error) {
cfg, err := config.LoadFromFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to load config: %w", err)
}
return NewFactory(cfg), nil
}
// GetEmailClient 获取邮件客户端(已初始化)
// 返回已初始化的邮件客户端对象,可直接使用
func (f *Factory) GetEmailClient() (*email.Email, error) {
if f.cfg.Email == nil {
return nil, fmt.Errorf("email config is nil")
}
return email.NewEmail(f.cfg.Email)
}
// GetSMSClient 获取短信客户端(已初始化)
// 返回已初始化的短信客户端对象,可直接使用
func (f *Factory) GetSMSClient() (*sms.SMS, error) {
if f.cfg.SMS == nil {
return nil, fmt.Errorf("SMS config is nil")
}
return sms.NewSMS(f.cfg.SMS)
}
// GetLogger 获取日志记录器(已初始化)
// 返回已初始化的日志记录器对象,可直接使用
func (f *Factory) GetLogger() (*logger.Logger, error) {
if f.cfg.Logger == nil {
// 如果没有配置,使用默认配置创建
return logger.NewLogger(nil)
}
return logger.NewLogger(f.cfg.Logger)
}
// GetDatabase 获取数据库连接对象(已初始化)
// 返回已初始化的GORM数据库对象可直接使用
func (f *Factory) GetDatabase() (*gorm.DB, error) {
if f.cfg.Database == nil {
return nil, fmt.Errorf("database config is nil")
}
// 获取DSN
dsn, err := f.cfg.GetDatabaseDSN()
if err != nil {
return nil, fmt.Errorf("failed to get DSN: %w", err)
}
// 根据数据库类型创建连接
var db *gorm.DB
switch f.cfg.Database.Type {
case "mysql":
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
case "postgres":
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
case "sqlite":
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
default:
return nil, fmt.Errorf("unsupported database type: %s", f.cfg.Database.Type)
}
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// 配置连接池
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
}
if f.cfg.Database.MaxOpenConns > 0 {
sqlDB.SetMaxOpenConns(f.cfg.Database.MaxOpenConns)
}
if f.cfg.Database.MaxIdleConns > 0 {
sqlDB.SetMaxIdleConns(f.cfg.Database.MaxIdleConns)
}
if f.cfg.Database.ConnMaxLifetime > 0 {
sqlDB.SetConnMaxLifetime(time.Duration(f.cfg.Database.ConnMaxLifetime) * time.Second)
}
return db, nil
}
// GetRedisClient 获取Redis客户端对象已初始化
// 返回已初始化的Redis客户端对象可直接使用
func (f *Factory) GetRedisClient() (*redis.Client, error) {
if f.cfg.Redis == nil {
return nil, fmt.Errorf("redis config is nil")
}
// 获取Redis地址
addr := f.cfg.GetRedisAddr()
if addr == "" {
return nil, fmt.Errorf("redis address is empty")
}
// 设置默认值
redisConfig := f.cfg.Redis
if redisConfig.PoolSize == 0 {
redisConfig.PoolSize = 10 // 默认连接池大小
}
if redisConfig.MinIdleConns == 0 {
redisConfig.MinIdleConns = 5 // 默认最小空闲连接数
}
if redisConfig.DialTimeout == 0 {
redisConfig.DialTimeout = 5 // 默认连接超时5秒
}
if redisConfig.ReadTimeout == 0 {
redisConfig.ReadTimeout = 3 // 默认读取超时3秒
}
if redisConfig.WriteTimeout == 0 {
redisConfig.WriteTimeout = 3 // 默认写入超时3秒
}
// 创建Redis客户端
client := redis.NewClient(&redis.Options{
Addr: addr,
Password: redisConfig.Password,
DB: redisConfig.Database,
PoolSize: redisConfig.PoolSize,
MinIdleConns: redisConfig.MinIdleConns,
MaxRetries: redisConfig.MaxRetries,
DialTimeout: time.Duration(redisConfig.DialTimeout) * time.Second,
ReadTimeout: time.Duration(redisConfig.ReadTimeout) * time.Second,
WriteTimeout: time.Duration(redisConfig.WriteTimeout) * time.Second,
})
// 测试连接
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(redisConfig.DialTimeout)*time.Second)
defer cancel()
_, err := client.Ping(ctx).Result()
if err != nil {
client.Close() // 连接失败时关闭客户端
return nil, fmt.Errorf("failed to connect to redis: %w", err)
}
return client, nil
}
// GetRedisConfig 获取Redis配置用于创建Redis客户端
// 返回Redis配置对象调用方可以使用此配置创建Redis客户端
// 注意:推荐使用 GetRedisClient 方法直接获取已初始化的客户端
func (f *Factory) GetRedisConfig() *config.RedisConfig {
return f.cfg.Redis
}
// GetConfig 获取配置对象
func (f *Factory) GetConfig() *config.Config {
return f.cfg
}