初始版本,工具基础类

This commit is contained in:
2025-11-30 13:02:34 +08:00
commit ea4e2e305d
37 changed files with 7480 additions and 0 deletions

36
middleware/chain.go Normal file
View File

@@ -0,0 +1,36 @@
package middleware
import "net/http"
// Chain 中间件链
type Chain struct {
middlewares []func(http.Handler) http.Handler
}
// NewChain 创建新的中间件链
func NewChain(middlewares ...func(http.Handler) http.Handler) *Chain {
return &Chain{
middlewares: middlewares,
}
}
// Then 将中间件链应用到处理器
func (c *Chain) Then(handler http.Handler) http.Handler {
final := handler
for i := len(c.middlewares) - 1; i >= 0; i-- {
final = c.middlewares[i](final)
}
return final
}
// ThenFunc 将中间件链应用到处理器函数
func (c *Chain) ThenFunc(handler http.HandlerFunc) http.Handler {
return c.Then(handler)
}
// Append 追加中间件
func (c *Chain) Append(middlewares ...func(http.Handler) http.Handler) *Chain {
c.middlewares = append(c.middlewares, middlewares...)
return c
}

207
middleware/cors.go Normal file
View File

@@ -0,0 +1,207 @@
package middleware
import (
"net/http"
"strconv"
"strings"
)
// CORSConfig CORS配置
type CORSConfig struct {
// AllowedOrigins 允许的源,支持通配符 "*" 表示允许所有源
// 例如: []string{"*"} 或 []string{"https://example.com", "https://app.example.com"}
AllowedOrigins []string
// AllowedMethods 允许的HTTP方法
// 默认: []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"}
AllowedMethods []string
// AllowedHeaders 允许的请求头
// 默认: []string{"Content-Type", "Authorization", "X-Requested-With", "X-Timezone"}
AllowedHeaders []string
// ExposedHeaders 暴露给客户端的响应头
ExposedHeaders []string
// AllowCredentials 是否允许发送凭证cookies等
AllowCredentials bool
// MaxAge 预检请求的缓存时间(秒)
// 默认: 86400 (24小时)
MaxAge int
}
// DefaultCORSConfig 返回默认的CORS配置
func DefaultCORSConfig() *CORSConfig {
return &CORSConfig{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"},
AllowedHeaders: []string{"Content-Type", "Authorization", "X-Requested-With", "X-Timezone"},
ExposedHeaders: []string{},
AllowCredentials: false,
MaxAge: 86400,
}
}
// CORS CORS中间件
func CORS(config ...*CORSConfig) func(http.Handler) http.Handler {
var cfg *CORSConfig
if len(config) > 0 && config[0] != nil {
cfg = config[0]
} else {
cfg = DefaultCORSConfig()
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
// 处理预检请求
if r.Method == http.MethodOptions {
handlePreflight(w, r, cfg, origin)
return
}
// 处理实际请求
handleCORSHeaders(w, r, cfg, origin)
next.ServeHTTP(w, r)
})
}
}
// handlePreflight 处理预检请求
func handlePreflight(w http.ResponseWriter, r *http.Request, cfg *CORSConfig, origin string) {
// 检查源是否允许
if !isOriginAllowed(origin, cfg.AllowedOrigins) {
w.WriteHeader(http.StatusForbidden)
return
}
// 设置CORS响应头
setCORSHeaders(w, cfg, origin)
// 检查请求的方法是否允许
requestMethod := r.Header.Get("Access-Control-Request-Method")
if requestMethod != "" && !isMethodAllowed(requestMethod, cfg.AllowedMethods) {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
// 检查请求头是否允许
requestHeaders := r.Header.Get("Access-Control-Request-Headers")
if requestHeaders != "" {
headers := strings.Split(strings.ToLower(requestHeaders), ",")
for _, header := range headers {
header = strings.TrimSpace(header)
if !isHeaderAllowed(header, cfg.AllowedHeaders) {
w.WriteHeader(http.StatusForbidden)
return
}
}
}
w.WriteHeader(http.StatusNoContent)
}
// handleCORSHeaders 处理实际请求的CORS头
func handleCORSHeaders(w http.ResponseWriter, r *http.Request, cfg *CORSConfig, origin string) {
// 检查源是否允许
if !isOriginAllowed(origin, cfg.AllowedOrigins) {
return
}
// 设置CORS响应头
setCORSHeaders(w, cfg, origin)
}
// setCORSHeaders 设置CORS响应头
func setCORSHeaders(w http.ResponseWriter, cfg *CORSConfig, origin string) {
// Access-Control-Allow-Origin
if len(cfg.AllowedOrigins) == 1 && cfg.AllowedOrigins[0] == "*" {
if !cfg.AllowCredentials {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
// 如果允许凭证,不能使用 "*",必须返回具体的源
if origin != "" {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
}
} else {
if isOriginAllowed(origin, cfg.AllowedOrigins) {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
}
// Access-Control-Allow-Methods
if len(cfg.AllowedMethods) > 0 {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.AllowedMethods, ", "))
}
// Access-Control-Allow-Headers
if len(cfg.AllowedHeaders) > 0 {
w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.AllowedHeaders, ", "))
}
// Access-Control-Expose-Headers
if len(cfg.ExposedHeaders) > 0 {
w.Header().Set("Access-Control-Expose-Headers", strings.Join(cfg.ExposedHeaders, ", "))
}
// Access-Control-Allow-Credentials
if cfg.AllowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
// Access-Control-Max-Age
if cfg.MaxAge > 0 {
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(cfg.MaxAge))
}
}
// isOriginAllowed 检查源是否允许
func isOriginAllowed(origin string, allowedOrigins []string) bool {
if origin == "" {
return false
}
for _, allowed := range allowedOrigins {
if allowed == "*" {
return true
}
if allowed == origin {
return true
}
// 支持简单的通配符匹配(如 "*.example.com"
if strings.HasPrefix(allowed, "*.") {
domain := strings.TrimPrefix(allowed, "*.")
if strings.HasSuffix(origin, domain) {
return true
}
}
}
return false
}
// isMethodAllowed 检查方法是否允许
func isMethodAllowed(method string, allowedMethods []string) bool {
method = strings.ToUpper(strings.TrimSpace(method))
for _, allowed := range allowedMethods {
if strings.ToUpper(allowed) == method {
return true
}
}
return false
}
// isHeaderAllowed 检查请求头是否允许
func isHeaderAllowed(header string, allowedHeaders []string) bool {
header = strings.ToLower(strings.TrimSpace(header))
for _, allowed := range allowedHeaders {
if strings.ToLower(allowed) == header {
return true
}
}
return false
}

82
middleware/timezone.go Normal file
View File

@@ -0,0 +1,82 @@
package middleware
import (
"context"
"net/http"
"github.com/go-common/datetime"
)
// TimezoneKey context中存储时区的key
type timezoneKey struct{}
// TimezoneHeaderName 时区请求头名称
const TimezoneHeaderName = "X-Timezone"
// DefaultTimezone 默认时区
const DefaultTimezone = datetime.AsiaShanghai
// GetTimezoneFromContext 从context中获取时区
func GetTimezoneFromContext(ctx context.Context) string {
if tz, ok := ctx.Value(timezoneKey{}).(string); ok && tz != "" {
return tz
}
return DefaultTimezone
}
// Timezone 时区处理中间件
// 从请求头 X-Timezone 读取时区信息,如果未传递则使用默认时区 AsiaShanghai
// 时区信息会存储到context中可以通过 GetTimezoneFromContext 获取
func Timezone(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 从请求头获取时区
timezone := r.Header.Get(TimezoneHeaderName)
// 如果未传递时区信息,使用默认时区
if timezone == "" {
timezone = DefaultTimezone
}
// 验证时区是否有效
if _, err := datetime.GetLocation(timezone); err != nil {
// 如果时区无效,使用默认时区
timezone = DefaultTimezone
}
// 将时区存储到context中
ctx := context.WithValue(r.Context(), timezoneKey{}, timezone)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// TimezoneWithDefault 时区处理中间件(可自定义默认时区)
// defaultTimezone: 默认时区,如果未指定则使用 AsiaShanghai
func TimezoneWithDefault(defaultTimezone string) func(http.Handler) http.Handler {
// 验证默认时区是否有效
if _, err := datetime.GetLocation(defaultTimezone); err != nil {
defaultTimezone = DefaultTimezone
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 从请求头获取时区
timezone := r.Header.Get(TimezoneHeaderName)
// 如果未传递时区信息,使用指定的默认时区
if timezone == "" {
timezone = defaultTimezone
}
// 验证时区是否有效
if _, err := datetime.GetLocation(timezone); err != nil {
// 如果时区无效,使用默认时区
timezone = defaultTimezone
}
// 将时区存储到context中
ctx := context.WithValue(r.Context(), timezoneKey{}, timezone)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}