238 lines
6.3 KiB
Go
238 lines
6.3 KiB
Go
package middleware
|
||
|
||
import (
|
||
"net/http"
|
||
"strconv"
|
||
"strings"
|
||
)
|
||
|
||
// CORSConfig holds the settings used by the CORS middleware.
type CORSConfig struct {
	// AllowedOrigins lists the origins allowed to make cross-origin requests.
	// The single entry "*" allows any origin; an entry starting with "*."
	// (e.g. "*.example.com") is treated as a domain wildcard by
	// isOriginAllowed; any other entry must equal the request's Origin header
	// exactly, e.g. []string{"https://example.com", "https://app.example.com"}.
	AllowedOrigins []string

	// AllowedMethods lists the HTTP methods accepted in preflight requests and
	// advertised via Access-Control-Allow-Methods (matched case-insensitively).
	// DefaultCORSConfig uses GET, POST, PUT, DELETE, PATCH and OPTIONS.
	AllowedMethods []string

	// AllowedHeaders lists the request headers accepted in preflight requests
	// and advertised via Access-Control-Allow-Headers (matched
	// case-insensitively). DefaultCORSConfig uses Content-Type, Authorization,
	// X-Requested-With and X-Timezone.
	AllowedHeaders []string

	// ExposedHeaders lists response headers made readable to client scripts
	// via Access-Control-Expose-Headers. Nothing is sent when it is empty.
	ExposedHeaders []string

	// AllowCredentials controls whether "Access-Control-Allow-Credentials: true"
	// is sent. When it is true and AllowedOrigins is exactly {"*"}, the
	// request's origin is echoed back instead of "*".
	AllowCredentials bool

	// MaxAge is the Access-Control-Max-Age value in seconds, i.e. how long a
	// preflight result may be cached. The header is only sent when MaxAge > 0.
	// DefaultCORSConfig and NewCORSConfig default it to 86400 (24 hours).
	MaxAge int
}
|
||
|
||
// DefaultCORSConfig 返回默认的CORS配置
|
||
func DefaultCORSConfig() *CORSConfig {
|
||
return &CORSConfig{
|
||
AllowedOrigins: []string{"*"},
|
||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"},
|
||
AllowedHeaders: []string{"Content-Type", "Authorization", "X-Requested-With", "X-Timezone"},
|
||
ExposedHeaders: []string{},
|
||
AllowCredentials: false,
|
||
MaxAge: 86400,
|
||
}
|
||
}
|
||
|
||
// NewCORSConfig 从配置参数创建 CORSConfig
|
||
// 用于从 config 包的 CORSConfig 转换为 middleware 的 CORSConfig
|
||
// 避免循环依赖
|
||
func NewCORSConfig(allowedOrigins, allowedMethods, allowedHeaders, exposedHeaders []string, allowCredentials bool, maxAge int) *CORSConfig {
|
||
cfg := &CORSConfig{
|
||
AllowedOrigins: allowedOrigins,
|
||
AllowedMethods: allowedMethods,
|
||
AllowedHeaders: allowedHeaders,
|
||
ExposedHeaders: exposedHeaders,
|
||
AllowCredentials: allowCredentials,
|
||
MaxAge: maxAge,
|
||
}
|
||
|
||
// 设置默认值(如果为空)
|
||
if len(cfg.AllowedOrigins) == 0 {
|
||
cfg.AllowedOrigins = []string{"*"}
|
||
}
|
||
if len(cfg.AllowedMethods) == 0 {
|
||
cfg.AllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"}
|
||
}
|
||
if len(cfg.AllowedHeaders) == 0 {
|
||
cfg.AllowedHeaders = []string{"Content-Type", "Authorization", "X-Requested-With", "X-Timezone"}
|
||
}
|
||
if cfg.MaxAge == 0 {
|
||
cfg.MaxAge = 86400
|
||
}
|
||
|
||
return cfg
|
||
}
|
||
|
||
// CORS CORS中间件
|
||
func CORS(config ...*CORSConfig) func(http.Handler) http.Handler {
|
||
var cfg *CORSConfig
|
||
if len(config) > 0 && config[0] != nil {
|
||
cfg = config[0]
|
||
} else {
|
||
cfg = DefaultCORSConfig()
|
||
}
|
||
|
||
return func(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
origin := r.Header.Get("Origin")
|
||
|
||
// 处理预检请求
|
||
if r.Method == http.MethodOptions {
|
||
handlePreflight(w, r, cfg, origin)
|
||
return
|
||
}
|
||
|
||
// 处理实际请求
|
||
handleCORSHeaders(w, r, cfg, origin)
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
}
|
||
|
||
// handlePreflight 处理预检请求
|
||
func handlePreflight(w http.ResponseWriter, r *http.Request, cfg *CORSConfig, origin string) {
|
||
// 检查源是否允许
|
||
if !isOriginAllowed(origin, cfg.AllowedOrigins) {
|
||
w.WriteHeader(http.StatusForbidden)
|
||
return
|
||
}
|
||
|
||
// 设置CORS响应头
|
||
setCORSHeaders(w, cfg, origin)
|
||
|
||
// 检查请求的方法是否允许
|
||
requestMethod := r.Header.Get("Access-Control-Request-Method")
|
||
if requestMethod != "" && !isMethodAllowed(requestMethod, cfg.AllowedMethods) {
|
||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 检查请求头是否允许
|
||
requestHeaders := r.Header.Get("Access-Control-Request-Headers")
|
||
if requestHeaders != "" {
|
||
headers := strings.Split(strings.ToLower(requestHeaders), ",")
|
||
for _, header := range headers {
|
||
header = strings.TrimSpace(header)
|
||
if !isHeaderAllowed(header, cfg.AllowedHeaders) {
|
||
w.WriteHeader(http.StatusForbidden)
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
w.WriteHeader(http.StatusNoContent)
|
||
}
|
||
|
||
// handleCORSHeaders 处理实际请求的CORS头
|
||
func handleCORSHeaders(w http.ResponseWriter, r *http.Request, cfg *CORSConfig, origin string) {
|
||
// 检查源是否允许
|
||
if !isOriginAllowed(origin, cfg.AllowedOrigins) {
|
||
return
|
||
}
|
||
|
||
// 设置CORS响应头
|
||
setCORSHeaders(w, cfg, origin)
|
||
}
|
||
|
||
// setCORSHeaders 设置CORS响应头
|
||
func setCORSHeaders(w http.ResponseWriter, cfg *CORSConfig, origin string) {
|
||
// Access-Control-Allow-Origin
|
||
if len(cfg.AllowedOrigins) == 1 && cfg.AllowedOrigins[0] == "*" {
|
||
if !cfg.AllowCredentials {
|
||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||
} else {
|
||
// 如果允许凭证,不能使用 "*",必须返回具体的源
|
||
if origin != "" {
|
||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||
}
|
||
}
|
||
} else {
|
||
if isOriginAllowed(origin, cfg.AllowedOrigins) {
|
||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||
}
|
||
}
|
||
|
||
// Access-Control-Allow-Methods
|
||
if len(cfg.AllowedMethods) > 0 {
|
||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.AllowedMethods, ", "))
|
||
}
|
||
|
||
// Access-Control-Allow-Headers
|
||
if len(cfg.AllowedHeaders) > 0 {
|
||
w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.AllowedHeaders, ", "))
|
||
}
|
||
|
||
// Access-Control-Expose-Headers
|
||
if len(cfg.ExposedHeaders) > 0 {
|
||
w.Header().Set("Access-Control-Expose-Headers", strings.Join(cfg.ExposedHeaders, ", "))
|
||
}
|
||
|
||
// Access-Control-Allow-Credentials
|
||
if cfg.AllowCredentials {
|
||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||
}
|
||
|
||
// Access-Control-Max-Age
|
||
if cfg.MaxAge > 0 {
|
||
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(cfg.MaxAge))
|
||
}
|
||
}
|
||
|
||
// isOriginAllowed reports whether origin (the value of the request's Origin
// header, e.g. "https://app.example.com") matches one of allowedOrigins.
//
// Matching rules, per entry:
//   - "*" matches any non-empty origin;
//   - an exact string match always succeeds;
//   - "*.example.com" matches origins whose host is example.com or ends in
//     ".example.com", e.g. "https://a.example.com" and "https://example.com".
//     It does NOT match look-alike domains such as "https://evilexample.com".
//     Because the comparison is a plain suffix on the full origin, an origin
//     with an explicit port only matches a pattern that includes that port
//     (e.g. "*.example.com:8443").
//
// An empty origin (no Origin header) is never allowed.
func isOriginAllowed(origin string, allowedOrigins []string) bool {
	if origin == "" {
		return false
	}

	for _, allowed := range allowedOrigins {
		if allowed == "*" || allowed == origin {
			return true
		}

		if !strings.HasPrefix(allowed, "*.") {
			continue
		}
		domain := strings.TrimPrefix(allowed, "*.")
		if domain == "" {
			// A bare "*." would otherwise match every origin.
			continue
		}
		// Require a label boundary before the domain: either a subdomain dot
		// or the scheme separator (apex domain). A plain suffix check would
		// let "https://evilexample.com" match "*.example.com".
		if strings.HasSuffix(origin, "."+domain) || strings.HasSuffix(origin, "://"+domain) {
			return true
		}
	}

	return false
}
|
||
|
||
// isMethodAllowed reports whether method appears in allowedMethods. Both
// sides are compared after strings.ToUpper, so the match is case-insensitive;
// surrounding whitespace is trimmed from method only, not from the entries of
// allowedMethods.
func isMethodAllowed(method string, allowedMethods []string) bool {
	target := strings.ToUpper(strings.TrimSpace(method))
	for i := 0; i < len(allowedMethods); i++ {
		if target == strings.ToUpper(allowedMethods[i]) {
			return true
		}
	}
	return false
}
|
||
|
||
// isHeaderAllowed reports whether header appears in allowedHeaders. Both
// sides are compared after strings.ToLower, so the match is case-insensitive;
// surrounding whitespace is trimmed from header only, not from the entries of
// allowedHeaders.
func isHeaderAllowed(header string, allowedHeaders []string) bool {
	needle := strings.ToLower(strings.TrimSpace(header))
	for i := 0; i < len(allowedHeaders); i++ {
		if needle == strings.ToLower(allowedHeaders[i]) {
			return true
		}
	}
	return false
}
|
||
|