修改对象的获取方式
This commit is contained in:
@@ -2,11 +2,16 @@ package factory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.toowon.com/jimmy/go-common/config"
|
||||
"git.toowon.com/jimmy/go-common/email"
|
||||
"git.toowon.com/jimmy/go-common/logger"
|
||||
"git.toowon.com/jimmy/go-common/sms"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Factory 工厂类,用于从配置创建各种客户端对象
|
||||
@@ -21,6 +26,16 @@ func NewFactory(cfg *config.Config) *Factory {
|
||||
}
|
||||
}
|
||||
|
||||
// NewFactoryFromFile 从配置文件创建工厂实例(便捷方法)
|
||||
// filePath: 配置文件路径
|
||||
func NewFactoryFromFile(filePath string) (*Factory, error) {
|
||||
cfg, err := config.LoadFromFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
return NewFactory(cfg), nil
|
||||
}
|
||||
|
||||
// GetEmailClient 获取邮件客户端(已初始化)
|
||||
// 返回已初始化的邮件客户端对象,可直接使用
|
||||
func (f *Factory) GetEmailClient() (*email.Email, error) {
|
||||
@@ -49,8 +64,63 @@ func (f *Factory) GetLogger() (*logger.Logger, error) {
|
||||
return logger.NewLogger(f.cfg.Logger)
|
||||
}
|
||||
|
||||
// GetDatabase 获取数据库连接对象(已初始化)
|
||||
// 返回已初始化的GORM数据库对象,可直接使用
|
||||
func (f *Factory) GetDatabase() (*gorm.DB, error) {
|
||||
if f.cfg.Database == nil {
|
||||
return nil, fmt.Errorf("database config is nil")
|
||||
}
|
||||
|
||||
// 获取DSN
|
||||
dsn, err := f.cfg.GetDatabaseDSN()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get DSN: %w", err)
|
||||
}
|
||||
|
||||
// 根据数据库类型创建连接
|
||||
var db *gorm.DB
|
||||
switch f.cfg.Database.Type {
|
||||
case "mysql":
|
||||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||||
case "postgres":
|
||||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
case "sqlite":
|
||||
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", f.cfg.Database.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
|
||||
}
|
||||
|
||||
if f.cfg.Database.MaxOpenConns > 0 {
|
||||
sqlDB.SetMaxOpenConns(f.cfg.Database.MaxOpenConns)
|
||||
}
|
||||
if f.cfg.Database.MaxIdleConns > 0 {
|
||||
sqlDB.SetMaxIdleConns(f.cfg.Database.MaxIdleConns)
|
||||
}
|
||||
if f.cfg.Database.ConnMaxLifetime > 0 {
|
||||
sqlDB.SetConnMaxLifetime(time.Duration(f.cfg.Database.ConnMaxLifetime) * time.Second)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// GetRedisConfig 获取Redis配置(用于创建Redis客户端)
|
||||
// 返回Redis配置对象,调用方可以使用此配置创建Redis客户端
|
||||
// 注意:Go标准库没有Redis客户端,需要调用方使用第三方库(如go-redis/redis)创建
|
||||
func (f *Factory) GetRedisConfig() *config.RedisConfig {
|
||||
return f.cfg.Redis
|
||||
}
|
||||
|
||||
// GetConfig 获取配置对象
|
||||
func (f *Factory) GetConfig() *config.Config {
|
||||
return f.cfg
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user