499 lines
12 KiB
Go
499 lines
12 KiB
Go
package config
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
)
|
||
|
||
// Config 应用配置
|
||
type Config struct {
|
||
Database *DatabaseConfig `json:"database"`
|
||
OSS *OSSConfig `json:"oss"`
|
||
Redis *RedisConfig `json:"redis"`
|
||
CORS *CORSConfig `json:"cors"`
|
||
MinIO *MinIOConfig `json:"minio"`
|
||
Email *EmailConfig `json:"email"`
|
||
SMS *SMSConfig `json:"sms"`
|
||
Logger *LoggerConfig `json:"logger"`
|
||
}
|
||
|
||
// DatabaseConfig 数据库配置
|
||
type DatabaseConfig struct {
|
||
// Type 数据库类型: mysql, postgres, sqlite
|
||
Type string `json:"type"`
|
||
|
||
// Host 数据库主机
|
||
Host string `json:"host"`
|
||
|
||
// Port 数据库端口
|
||
Port int `json:"port"`
|
||
|
||
// User 数据库用户名
|
||
User string `json:"user"`
|
||
|
||
// Password 数据库密码
|
||
Password string `json:"password"`
|
||
|
||
// Database 数据库名称
|
||
Database string `json:"database"`
|
||
|
||
// Charset 字符集(MySQL使用)
|
||
Charset string `json:"charset"`
|
||
|
||
// MaxOpenConns 最大打开连接数
|
||
MaxOpenConns int `json:"maxOpenConns"`
|
||
|
||
// MaxIdleConns 最大空闲连接数
|
||
MaxIdleConns int `json:"maxIdleConns"`
|
||
|
||
// ConnMaxLifetime 连接最大生存时间(秒)
|
||
ConnMaxLifetime int `json:"connMaxLifetime"`
|
||
|
||
// DSN 数据库连接字符串(如果设置了,会优先使用)
|
||
DSN string `json:"dsn"`
|
||
}
|
||
|
||
// OSSConfig OSS对象存储配置
|
||
type OSSConfig struct {
|
||
// Provider 提供商: aliyun, tencent, aws, qiniu
|
||
Provider string `json:"provider"`
|
||
|
||
// Endpoint 端点地址
|
||
Endpoint string `json:"endpoint"`
|
||
|
||
// AccessKeyID 访问密钥ID
|
||
AccessKeyID string `json:"accessKeyId"`
|
||
|
||
// AccessKeySecret 访问密钥
|
||
AccessKeySecret string `json:"accessKeySecret"`
|
||
|
||
// Bucket 存储桶名称
|
||
Bucket string `json:"bucket"`
|
||
|
||
// Region 区域
|
||
Region string `json:"region"`
|
||
|
||
// UseSSL 是否使用SSL
|
||
UseSSL bool `json:"useSSL"`
|
||
|
||
// Domain 自定义域名(CDN域名)
|
||
Domain string `json:"domain"`
|
||
}
|
||
|
||
// RedisConfig Redis配置
|
||
type RedisConfig struct {
|
||
// Host Redis主机
|
||
Host string `json:"host"`
|
||
|
||
// Port Redis端口
|
||
Port int `json:"port"`
|
||
|
||
// Password Redis密码
|
||
Password string `json:"password"`
|
||
|
||
// Database Redis数据库编号
|
||
Database int `json:"database"`
|
||
|
||
// MaxRetries 最大重试次数
|
||
MaxRetries int `json:"maxRetries"`
|
||
|
||
// PoolSize 连接池大小
|
||
PoolSize int `json:"poolSize"`
|
||
|
||
// MinIdleConns 最小空闲连接数
|
||
MinIdleConns int `json:"minIdleConns"`
|
||
|
||
// DialTimeout 连接超时时间(秒)
|
||
DialTimeout int `json:"dialTimeout"`
|
||
|
||
// ReadTimeout 读取超时时间(秒)
|
||
ReadTimeout int `json:"readTimeout"`
|
||
|
||
// WriteTimeout 写入超时时间(秒)
|
||
WriteTimeout int `json:"writeTimeout"`
|
||
|
||
// Addr Redis地址(如果设置了,会优先使用,格式: host:port)
|
||
Addr string `json:"addr"`
|
||
}
|
||
|
||
// CORSConfig CORS配置(与middleware.CORSConfig兼容)
|
||
type CORSConfig struct {
|
||
// AllowedOrigins 允许的源
|
||
AllowedOrigins []string `json:"allowedOrigins"`
|
||
|
||
// AllowedMethods 允许的HTTP方法
|
||
AllowedMethods []string `json:"allowedMethods"`
|
||
|
||
// AllowedHeaders 允许的请求头
|
||
AllowedHeaders []string `json:"allowedHeaders"`
|
||
|
||
// ExposedHeaders 暴露给客户端的响应头
|
||
ExposedHeaders []string `json:"exposedHeaders"`
|
||
|
||
// AllowCredentials 是否允许发送凭证
|
||
AllowCredentials bool `json:"allowCredentials"`
|
||
|
||
// MaxAge 预检请求的缓存时间(秒)
|
||
MaxAge int `json:"maxAge"`
|
||
}
|
||
|
||
// MinIOConfig MinIO配置
|
||
type MinIOConfig struct {
|
||
// Endpoint MinIO端点地址
|
||
Endpoint string `json:"endpoint"`
|
||
|
||
// AccessKeyID 访问密钥ID
|
||
AccessKeyID string `json:"accessKeyId"`
|
||
|
||
// SecretAccessKey 密钥
|
||
SecretAccessKey string `json:"secretAccessKey"`
|
||
|
||
// UseSSL 是否使用SSL
|
||
UseSSL bool `json:"useSSL"`
|
||
|
||
// Bucket 存储桶名称
|
||
Bucket string `json:"bucket"`
|
||
|
||
// Region 区域
|
||
Region string `json:"region"`
|
||
|
||
// Domain 自定义域名
|
||
Domain string `json:"domain"`
|
||
}
|
||
|
||
// EmailConfig 邮件配置
|
||
type EmailConfig struct {
|
||
// Host SMTP服务器地址
|
||
Host string `json:"host"`
|
||
|
||
// Port SMTP服务器端口
|
||
Port int `json:"port"`
|
||
|
||
// Username 发件人邮箱
|
||
Username string `json:"username"`
|
||
|
||
// Password 邮箱密码或授权码
|
||
Password string `json:"password"`
|
||
|
||
// From 发件人邮箱地址(如果为空,使用Username)
|
||
From string `json:"from"`
|
||
|
||
// FromName 发件人名称
|
||
FromName string `json:"fromName"`
|
||
|
||
// UseTLS 是否使用TLS
|
||
UseTLS bool `json:"useTLS"`
|
||
|
||
// UseSSL 是否使用SSL
|
||
UseSSL bool `json:"useSSL"`
|
||
|
||
// Timeout 连接超时时间(秒)
|
||
Timeout int `json:"timeout"`
|
||
}
|
||
|
||
// SMSConfig 短信配置(阿里云短信)
|
||
type SMSConfig struct {
|
||
// AccessKeyID 阿里云AccessKey ID
|
||
AccessKeyID string `json:"accessKeyId"`
|
||
|
||
// AccessKeySecret 阿里云AccessKey Secret
|
||
AccessKeySecret string `json:"accessKeySecret"`
|
||
|
||
// Region 区域(如:cn-hangzhou)
|
||
Region string `json:"region"`
|
||
|
||
// SignName 短信签名
|
||
SignName string `json:"signName"`
|
||
|
||
// TemplateCode 短信模板代码
|
||
TemplateCode string `json:"templateCode"`
|
||
|
||
// Endpoint 服务端点(可选,默认使用区域端点)
|
||
Endpoint string `json:"endpoint"`
|
||
|
||
// Timeout 请求超时时间(秒)
|
||
Timeout int `json:"timeout"`
|
||
}
|
||
|
||
// LoggerConfig 日志配置
|
||
type LoggerConfig struct {
|
||
// Level 日志级别: debug, info, warn, error
|
||
Level string `json:"level"`
|
||
|
||
// Output 输出方式: stdout, stderr, file, both
|
||
Output string `json:"output"`
|
||
|
||
// FilePath 日志文件路径(当output为file或both时必需)
|
||
FilePath string `json:"filePath"`
|
||
|
||
// Prefix 日志前缀
|
||
Prefix string `json:"prefix"`
|
||
|
||
// DisableTimestamp 禁用时间戳
|
||
DisableTimestamp bool `json:"disableTimestamp"`
|
||
|
||
// Async 是否使用异步模式(默认false,即同步模式)
|
||
// 异步模式:日志写入通过channel异步处理,不阻塞调用方
|
||
// 同步模式:日志直接写入,会阻塞调用方直到写入完成
|
||
Async bool `json:"async"`
|
||
|
||
// BufferSize 异步模式下的缓冲区大小(默认1000)
|
||
// 当缓冲区满时,新的日志会阻塞直到有空间
|
||
BufferSize int `json:"bufferSize"`
|
||
}
|
||
|
||
// LoadFromFile 从文件加载配置
|
||
// filePath: 配置文件路径(支持绝对路径和相对路径)
|
||
func LoadFromFile(filePath string) (*Config, error) {
|
||
// 转换为绝对路径
|
||
absPath, err := filepath.Abs(filePath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get absolute path: %w", err)
|
||
}
|
||
|
||
// 读取文件
|
||
data, err := os.ReadFile(absPath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||
}
|
||
|
||
// 解析JSON
|
||
var config Config
|
||
if err := json.Unmarshal(data, &config); err != nil {
|
||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||
}
|
||
|
||
// 设置默认值
|
||
config.setDefaults()
|
||
|
||
return &config, nil
|
||
}
|
||
|
||
// LoadFromBytes 从字节数组加载配置
|
||
func LoadFromBytes(data []byte) (*Config, error) {
|
||
var config Config
|
||
if err := json.Unmarshal(data, &config); err != nil {
|
||
return nil, fmt.Errorf("failed to parse config: %w", err)
|
||
}
|
||
|
||
// 设置默认值
|
||
config.setDefaults()
|
||
|
||
return &config, nil
|
||
}
|
||
|
||
// setDefaults 设置默认值
|
||
func (c *Config) setDefaults() {
|
||
// 数据库默认值
|
||
if c.Database != nil {
|
||
if c.Database.Charset == "" {
|
||
c.Database.Charset = "utf8mb4"
|
||
}
|
||
if c.Database.MaxOpenConns == 0 {
|
||
c.Database.MaxOpenConns = 100
|
||
}
|
||
if c.Database.MaxIdleConns == 0 {
|
||
c.Database.MaxIdleConns = 10
|
||
}
|
||
if c.Database.ConnMaxLifetime == 0 {
|
||
c.Database.ConnMaxLifetime = 3600
|
||
}
|
||
}
|
||
|
||
// Redis默认值
|
||
if c.Redis != nil {
|
||
if c.Redis.Port == 0 {
|
||
c.Redis.Port = 6379
|
||
}
|
||
if c.Redis.Database == 0 {
|
||
c.Redis.Database = 0
|
||
}
|
||
if c.Redis.MaxRetries == 0 {
|
||
c.Redis.MaxRetries = 3
|
||
}
|
||
if c.Redis.PoolSize == 0 {
|
||
c.Redis.PoolSize = 10
|
||
}
|
||
if c.Redis.MinIdleConns == 0 {
|
||
c.Redis.MinIdleConns = 5
|
||
}
|
||
if c.Redis.DialTimeout == 0 {
|
||
c.Redis.DialTimeout = 5
|
||
}
|
||
if c.Redis.ReadTimeout == 0 {
|
||
c.Redis.ReadTimeout = 3
|
||
}
|
||
if c.Redis.WriteTimeout == 0 {
|
||
c.Redis.WriteTimeout = 3
|
||
}
|
||
}
|
||
|
||
// CORS默认值
|
||
if c.CORS != nil {
|
||
if len(c.CORS.AllowedOrigins) == 0 {
|
||
c.CORS.AllowedOrigins = []string{"*"}
|
||
}
|
||
if len(c.CORS.AllowedMethods) == 0 {
|
||
c.CORS.AllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"}
|
||
}
|
||
if len(c.CORS.AllowedHeaders) == 0 {
|
||
c.CORS.AllowedHeaders = []string{"Content-Type", "Authorization", "X-Requested-With", "X-Timezone"}
|
||
}
|
||
if c.CORS.MaxAge == 0 {
|
||
c.CORS.MaxAge = 86400
|
||
}
|
||
}
|
||
|
||
// 邮件默认值
|
||
if c.Email != nil {
|
||
if c.Email.Port == 0 {
|
||
c.Email.Port = 587 // 默认使用587端口(TLS)
|
||
}
|
||
if c.Email.From == "" {
|
||
c.Email.From = c.Email.Username
|
||
}
|
||
if c.Email.Timeout == 0 {
|
||
c.Email.Timeout = 30
|
||
}
|
||
}
|
||
|
||
// 短信默认值
|
||
if c.SMS != nil {
|
||
if c.SMS.Region == "" {
|
||
c.SMS.Region = "cn-hangzhou"
|
||
}
|
||
if c.SMS.Timeout == 0 {
|
||
c.SMS.Timeout = 10
|
||
}
|
||
}
|
||
|
||
// 日志默认值
|
||
if c.Logger != nil {
|
||
if c.Logger.Level == "" {
|
||
c.Logger.Level = "info"
|
||
}
|
||
if c.Logger.Output == "" {
|
||
c.Logger.Output = "stdout"
|
||
}
|
||
}
|
||
}
|
||
|
||
// GetDatabase 获取数据库配置
|
||
func (c *Config) GetDatabase() *DatabaseConfig {
|
||
return c.Database
|
||
}
|
||
|
||
// GetOSS 获取OSS配置
|
||
func (c *Config) GetOSS() *OSSConfig {
|
||
return c.OSS
|
||
}
|
||
|
||
// GetRedis 获取Redis配置
|
||
func (c *Config) GetRedis() *RedisConfig {
|
||
return c.Redis
|
||
}
|
||
|
||
// GetCORS 获取CORS配置
|
||
// 返回的是 config.CORSConfig,需要转换为 middleware.CORSConfig 时
|
||
// 可以使用 middleware.CORSFromConfig() 函数
|
||
func (c *Config) GetCORS() *CORSConfig {
|
||
return c.CORS
|
||
}
|
||
|
||
// GetMinIO 获取MinIO配置
|
||
func (c *Config) GetMinIO() *MinIOConfig {
|
||
return c.MinIO
|
||
}
|
||
|
||
// GetEmail 获取邮件配置
|
||
func (c *Config) GetEmail() *EmailConfig {
|
||
return c.Email
|
||
}
|
||
|
||
// GetSMS 获取短信配置
|
||
func (c *Config) GetSMS() *SMSConfig {
|
||
return c.SMS
|
||
}
|
||
|
||
// GetLogger 获取日志配置
|
||
func (c *Config) GetLogger() *LoggerConfig {
|
||
return c.Logger
|
||
}
|
||
|
||
// GetDatabaseDSN 获取数据库连接字符串
|
||
func (c *Config) GetDatabaseDSN() (string, error) {
|
||
if c.Database == nil {
|
||
return "", fmt.Errorf("database config is nil")
|
||
}
|
||
|
||
// 如果已经设置了DSN,直接返回
|
||
if c.Database.DSN != "" {
|
||
return c.Database.DSN, nil
|
||
}
|
||
|
||
// 根据数据库类型生成DSN
|
||
switch c.Database.Type {
|
||
case "mysql":
|
||
return c.buildMySQLDSN(), nil
|
||
case "postgres":
|
||
return c.buildPostgresDSN(), nil
|
||
case "sqlite":
|
||
return c.Database.Database, nil
|
||
default:
|
||
return "", fmt.Errorf("unsupported database type: %s", c.Database.Type)
|
||
}
|
||
}
|
||
|
||
// buildMySQLDSN 构建MySQL连接字符串
|
||
// 注意:数据库时间统一使用UTC时间,不设置时区
|
||
func (c *Config) buildMySQLDSN() string {
|
||
db := c.Database
|
||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", db.User, db.Password, db.Host, db.Port, db.Database)
|
||
|
||
params := []string{}
|
||
if db.Charset != "" {
|
||
params = append(params, "charset="+db.Charset)
|
||
}
|
||
params = append(params, "parseTime=True")
|
||
params = append(params, "loc=UTC") // 统一使用UTC时区
|
||
|
||
if len(params) > 0 {
|
||
dsn += "?" + params[0]
|
||
for i := 1; i < len(params); i++ {
|
||
dsn += "&" + params[i]
|
||
}
|
||
}
|
||
|
||
return dsn
|
||
}
|
||
|
||
// buildPostgresDSN 构建PostgreSQL连接字符串
|
||
// 注意:数据库时间统一使用UTC时间,不设置时区
|
||
func (c *Config) buildPostgresDSN() string {
|
||
db := c.Database
|
||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s",
|
||
db.Host, db.Port, db.User, db.Password, db.Database)
|
||
|
||
// 统一使用UTC时区
|
||
dsn += " timezone=UTC"
|
||
dsn += " sslmode=disable"
|
||
return dsn
|
||
}
|
||
|
||
// GetRedisAddr 获取Redis地址
|
||
func (c *Config) GetRedisAddr() string {
|
||
if c.Redis == nil {
|
||
return ""
|
||
}
|
||
|
||
// 如果已经设置了Addr,直接返回
|
||
if c.Redis.Addr != "" {
|
||
return c.Redis.Addr
|
||
}
|
||
|
||
// 构建地址
|
||
return fmt.Sprintf("%s:%d", c.Redis.Host, c.Redis.Port)
|
||
}
|