package middleware

import (
	"net/http"
	"strconv"
	"strings"
)

// CORSConfig configures the CORS middleware.
//
// The zero value is NOT populated with defaults; use DefaultCORSConfig to
// obtain a ready-to-use configuration and adjust its fields as needed.
type CORSConfig struct {
	// AllowedOrigins lists the origins allowed to make cross-origin requests.
	// Each entry is one of:
	//   - an exact origin, e.g. "https://app.example.com";
	//   - "*", which allows any origin;
	//   - a domain pattern "*.example.com", which allows origins whose host
	//     is example.com or any subdomain of it, with any scheme. The port is
	//     part of the host, so an origin with an explicit port does not match.
	// The literal "*" is sent in Access-Control-Allow-Origin only when this
	// list is exactly []string{"*"} and AllowCredentials is false; otherwise
	// the request origin is echoed back.
	// Example: []string{"*"} or []string{"https://example.com", "https://app.example.com"}
	AllowedOrigins []string

	// AllowedMethods lists the HTTP methods accepted in preflight requests
	// (compared case-insensitively).
	// DefaultCORSConfig uses: GET, POST, PUT, DELETE, PATCH, OPTIONS.
	AllowedMethods []string

	// AllowedHeaders lists the request headers accepted in preflight requests
	// (compared case-insensitively).
	// DefaultCORSConfig uses: Content-Type, Authorization, X-Requested-With, X-Timezone.
	AllowedHeaders []string

	// ExposedHeaders lists response headers that browsers may expose to
	// client-side scripts.
	ExposedHeaders []string

	// AllowCredentials controls whether credentials (cookies etc.) may be
	// sent. When true, a "*" origin configuration answers with the concrete
	// request origin, because browsers reject "*" together with credentials.
	AllowCredentials bool

	// MaxAge is how long, in seconds, a preflight result may be cached.
	// Values <= 0 omit the Access-Control-Max-Age header.
	// DefaultCORSConfig uses 86400 (24 hours).
	MaxAge int
}

// DefaultCORSConfig returns a permissive default CORS configuration: any
// origin, common methods and headers, no credentials, 24h preflight cache.
func DefaultCORSConfig() *CORSConfig {
	return &CORSConfig{
		AllowedOrigins:   []string{"*"},
		AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"},
		AllowedHeaders:   []string{"Content-Type", "Authorization", "X-Requested-With", "X-Timezone"},
		ExposedHeaders:   []string{},
		AllowCredentials: false,
		MaxAge:           86400,
	}
}

// CORS returns a middleware that applies CORS headers according to config.
// Only the first argument is used; when it is absent or nil,
// DefaultCORSConfig is used.
//
// A request is treated as a preflight when it is an OPTIONS request carrying
// both an Origin and an Access-Control-Request-Method header (the Fetch
// standard's definition). Preflights are answered directly without calling
// next. Every other request, including a plain OPTIONS request, gets CORS
// headers (if its origin is allowed) and is passed on to next.
func CORS(config ...*CORSConfig) func(http.Handler) http.Handler {
	var cfg *CORSConfig
	if len(config) > 0 && config[0] != nil {
		cfg = config[0]
	} else {
		cfg = DefaultCORSConfig()
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			origin := r.Header.Get("Origin")

			// The CORS headers written below depend on the request's Origin,
			// so caches must key on it; otherwise a response cached for one
			// origin could be served to a different one.
			w.Header().Add("Vary", "Origin")

			if isPreflight(r, origin) {
				handlePreflight(w, r, cfg, origin)
				return
			}

			handleCORSHeaders(w, r, cfg, origin)
			next.ServeHTTP(w, r)
		})
	}
}

// isPreflight reports whether r is a CORS preflight request: an OPTIONS
// request with a non-empty Origin and Access-Control-Request-Method.
func isPreflight(r *http.Request, origin string) bool {
	return r.Method == http.MethodOptions &&
		origin != "" &&
		r.Header.Get("Access-Control-Request-Method") != ""
}

// handlePreflight answers a preflight request. It responds with:
//   - 403 if the origin is not allowed;
//   - 405 if the requested method is not in cfg.AllowedMethods;
//   - 403 if any requested header is not in cfg.AllowedHeaders;
//   - 204 with the CORS headers otherwise.
func handlePreflight(w http.ResponseWriter, r *http.Request, cfg *CORSConfig, origin string) {
	// The preflight answer also depends on the requested method and headers.
	w.Header().Add("Vary", "Access-Control-Request-Method")
	w.Header().Add("Vary", "Access-Control-Request-Headers")

	if !isOriginAllowed(origin, cfg.AllowedOrigins) {
		w.WriteHeader(http.StatusForbidden)
		return
	}

	setCORSHeaders(w, cfg, origin)

	requestMethod := r.Header.Get("Access-Control-Request-Method")
	if requestMethod != "" && !isMethodAllowed(requestMethod, cfg.AllowedMethods) {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	if requestHeaders := r.Header.Get("Access-Control-Request-Headers"); requestHeaders != "" {
		for _, header := range strings.Split(requestHeaders, ",") {
			header = strings.TrimSpace(header)
			// Tolerate empty list elements such as a trailing comma.
			if header == "" {
				continue
			}
			if !isHeaderAllowed(header, cfg.AllowedHeaders) {
				w.WriteHeader(http.StatusForbidden)
				return
			}
		}
	}

	w.WriteHeader(http.StatusNoContent)
}

// handleCORSHeaders sets CORS headers on a non-preflight request when its
// origin is allowed; otherwise it leaves the response untouched.
func handleCORSHeaders(w http.ResponseWriter, r *http.Request, cfg *CORSConfig, origin string) {
	if !isOriginAllowed(origin, cfg.AllowedOrigins) {
		return
	}
	setCORSHeaders(w, cfg, origin)
}

// setCORSHeaders writes the Access-Control-* response headers derived from
// cfg for the given (already validated) origin.
func setCORSHeaders(w http.ResponseWriter, cfg *CORSConfig, origin string) {
	// Access-Control-Allow-Origin
	if len(cfg.AllowedOrigins) == 1 && cfg.AllowedOrigins[0] == "*" {
		if !cfg.AllowCredentials {
			w.Header().Set("Access-Control-Allow-Origin", "*")
		} else {
			// With credentials, "*" is not accepted by browsers; echo the
			// concrete origin instead.
			if origin != "" {
				w.Header().Set("Access-Control-Allow-Origin", origin)
			}
		}
	} else {
		if isOriginAllowed(origin, cfg.AllowedOrigins) {
			w.Header().Set("Access-Control-Allow-Origin", origin)
		}
	}

	// Access-Control-Allow-Methods
	if len(cfg.AllowedMethods) > 0 {
		w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.AllowedMethods, ", "))
	}

	// Access-Control-Allow-Headers
	if len(cfg.AllowedHeaders) > 0 {
		w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.AllowedHeaders, ", "))
	}

	// Access-Control-Expose-Headers
	if len(cfg.ExposedHeaders) > 0 {
		w.Header().Set("Access-Control-Expose-Headers", strings.Join(cfg.ExposedHeaders, ", "))
	}

	// Access-Control-Allow-Credentials
	if cfg.AllowCredentials {
		w.Header().Set("Access-Control-Allow-Credentials", "true")
	}

	// Access-Control-Max-Age
	if cfg.MaxAge > 0 {
		w.Header().Set("Access-Control-Max-Age", strconv.Itoa(cfg.MaxAge))
	}
}

// isOriginAllowed reports whether origin matches any entry of allowedOrigins
// ("*", an exact origin, or a "*.domain" pattern). An empty origin is never
// allowed.
func isOriginAllowed(origin string, allowedOrigins []string) bool {
	if origin == "" {
		return false
	}
	for _, allowed := range allowedOrigins {
		switch {
		case allowed == "*", allowed == origin:
			return true
		case strings.HasPrefix(allowed, "*."):
			if matchesDomainPattern(origin, strings.TrimPrefix(allowed, "*.")) {
				return true
			}
		}
	}
	return false
}

// matchesDomainPattern reports whether the host part of origin (everything
// after "scheme://") equals domain or is a subdomain of it. A plain suffix
// check would let "https://evilexample.com" match "example.com", so the
// match is anchored on a "." label boundary.
func matchesDomainPattern(origin, domain string) bool {
	// An empty domain (pattern "*.") would otherwise match everything.
	if domain == "" {
		return false
	}
	host := origin
	if i := strings.Index(host, "://"); i >= 0 {
		host = host[i+len("://"):]
	}
	// A serialized origin never has a path, query, fragment or userinfo;
	// refuse such values so they cannot smuggle in a matching suffix.
	if host == "" || strings.ContainsAny(host, "/?#@") {
		return false
	}
	return host == domain || strings.HasSuffix(host, "."+domain)
}

// isMethodAllowed reports whether method (surrounding spaces ignored) is in
// allowedMethods, compared case-insensitively.
func isMethodAllowed(method string, allowedMethods []string) bool {
	method = strings.TrimSpace(method)
	for _, allowed := range allowedMethods {
		if strings.EqualFold(allowed, method) {
			return true
		}
	}
	return false
}

// isHeaderAllowed reports whether header (surrounding spaces ignored) is in
// allowedHeaders, compared case-insensitively as HTTP header names are.
func isHeaderAllowed(header string, allowedHeaders []string) bool {
	header = strings.TrimSpace(header)
	for _, allowed := range allowedHeaders {
		if strings.EqualFold(allowed, header) {
			return true
		}
	}
	return false
}