Files
go-common/middleware/cors.go

238 lines
6.3 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
	"net/http"
	"net/url"
	"strconv"
	"strings"
)
// CORSConfig holds the Cross-Origin Resource Sharing (CORS) settings used by
// the CORS middleware.
type CORSConfig struct {
// AllowedOrigins lists the origins that may make cross-origin requests.
// Supported entries: "*" (any origin), an exact origin such as
// "https://example.com", or a subdomain pattern such as "*.example.com"
// (see isOriginAllowed for the exact matching rules).
// Examples: []string{"*"} or []string{"https://example.com", "https://app.example.com"}
AllowedOrigins []string
// AllowedMethods lists the HTTP methods accepted in preflight requests
// (compared case-insensitively) and sent in Access-Control-Allow-Methods.
// Default (DefaultCORSConfig / NewCORSConfig):
// []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"}
AllowedMethods []string
// AllowedHeaders lists the request headers accepted in preflight requests
// (compared case-insensitively) and sent in Access-Control-Allow-Headers.
// Default (DefaultCORSConfig / NewCORSConfig):
// []string{"Content-Type", "Authorization", "X-Requested-With", "X-Timezone"}
AllowedHeaders []string
// ExposedHeaders lists the response headers made readable to client code via
// Access-Control-Expose-Headers; the header is omitted when this is empty.
ExposedHeaders []string
// AllowCredentials controls whether credentials (cookies etc.) are allowed.
// When true, "Access-Control-Allow-Credentials: true" is sent, and an
// AllowedOrigins of exactly {"*"} is answered by echoing the request origin
// instead of "*".
AllowCredentials bool
// MaxAge is how long, in seconds, a preflight result may be cached
// (Access-Control-Max-Age). Values <= 0 omit the header.
// DefaultCORSConfig uses 86400 (24 hours); NewCORSConfig maps 0 to 86400.
MaxAge int
}
// DefaultCORSConfig returns a new, permissive CORS configuration: every origin
// ("*"), the common REST methods, a small set of request headers, no exposed
// headers, credentials disabled, and a 24-hour preflight cache.
//
// Each call allocates fresh slices, so callers may modify the result freely.
func DefaultCORSConfig() *CORSConfig {
	methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"}
	headers := []string{"Content-Type", "Authorization", "X-Requested-With", "X-Timezone"}

	cfg := new(CORSConfig)
	cfg.AllowedOrigins = []string{"*"}
	cfg.AllowedMethods = methods
	cfg.AllowedHeaders = headers
	cfg.ExposedHeaders = []string{} // non-nil but empty: nothing is exposed
	cfg.MaxAge = 24 * 60 * 60       // one day, in seconds
	// AllowCredentials keeps its zero value (false).
	return cfg
}
// NewCORSConfig builds a CORSConfig from plain values. It lets the config
// package convert its own CORS settings into a middleware CORSConfig without
// the two packages importing each other (avoiding an import cycle).
//
// Empty allowedOrigins, allowedMethods and allowedHeaders, and a maxAge of 0,
// are replaced with the same defaults DefaultCORSConfig uses. exposedHeaders
// and allowCredentials are used as given. Non-empty slices are stored as-is,
// not copied.
func NewCORSConfig(allowedOrigins, allowedMethods, allowedHeaders, exposedHeaders []string, allowCredentials bool, maxAge int) *CORSConfig {
	// orDefault keeps a non-empty slice and otherwise substitutes the fallback.
	orDefault := func(given []string, fallback ...string) []string {
		if len(given) > 0 {
			return given
		}
		return fallback
	}

	if maxAge == 0 {
		maxAge = 86400
	}

	return &CORSConfig{
		AllowedOrigins:   orDefault(allowedOrigins, "*"),
		AllowedMethods:   orDefault(allowedMethods, "GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"),
		AllowedHeaders:   orDefault(allowedHeaders, "Content-Type", "Authorization", "X-Requested-With", "X-Timezone"),
		ExposedHeaders:   exposedHeaders,
		AllowCredentials: allowCredentials,
		MaxAge:           maxAge,
	}
}
// CORS returns an HTTP middleware that applies Cross-Origin Resource Sharing
// headers according to config. When no config (or a nil config) is passed,
// DefaultCORSConfig() is used.
//
// A request is handled as a CORS preflight only when it is an OPTIONS request
// that carries both an Origin and an Access-Control-Request-Method header.
// Preflights are answered directly and never reach next. Every other request,
// including plain OPTIONS requests that are not preflights, receives the CORS
// response headers (when its origin is allowed) and is passed on to next.
func CORS(config ...*CORSConfig) func(http.Handler) http.Handler {
	var cfg *CORSConfig
	if len(config) > 0 && config[0] != nil {
		cfg = config[0]
	} else {
		cfg = DefaultCORSConfig()
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			origin := r.Header.Get("Origin")

			// Previously every OPTIONS request was short-circuited, so an
			// ordinary OPTIONS request without an Origin was rejected with 403
			// and never reached the application. Only a genuine preflight is
			// intercepted now.
			isPreflight := r.Method == http.MethodOptions &&
				origin != "" &&
				r.Header.Get("Access-Control-Request-Method") != ""
			if isPreflight {
				handlePreflight(w, r, cfg, origin)
				return
			}

			handleCORSHeaders(w, r, cfg, origin)
			next.ServeHTTP(w, r)
		})
	}
}
// handlePreflight answers a CORS preflight (OPTIONS) request without invoking
// the wrapped handler. It writes:
//   - 403 Forbidden when origin is not allowed (no CORS headers are set);
//   - 405 Method Not Allowed when Access-Control-Request-Method names a method
//     missing from cfg.AllowedMethods;
//   - 403 Forbidden when any header listed in Access-Control-Request-Headers
//     is missing from cfg.AllowedHeaders;
//   - 204 No Content otherwise.
//
// In the last three cases the CORS response headers are set first.
func handlePreflight(w http.ResponseWriter, r *http.Request, cfg *CORSConfig, origin string) {
	if !isOriginAllowed(origin, cfg.AllowedOrigins) {
		w.WriteHeader(http.StatusForbidden)
		return
	}

	setCORSHeaders(w, cfg, origin)

	requestMethod := r.Header.Get("Access-Control-Request-Method")
	if requestMethod != "" && !isMethodAllowed(requestMethod, cfg.AllowedMethods) {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	// The requested-header list may span several header lines and may contain
	// empty elements (e.g. a trailing comma). Empty elements name no header
	// and must not cause a rejection. isHeaderAllowed trims and lower-cases
	// each name itself.
	for _, line := range r.Header.Values("Access-Control-Request-Headers") {
		for _, header := range strings.Split(line, ",") {
			header = strings.TrimSpace(header)
			if header == "" {
				continue
			}
			if !isHeaderAllowed(header, cfg.AllowedHeaders) {
				w.WriteHeader(http.StatusForbidden)
				return
			}
		}
	}

	w.WriteHeader(http.StatusNoContent)
}
// handleCORSHeaders adds CORS response headers for a non-preflight request.
// Requests whose origin is not allowed, including requests without an Origin
// header, are left untouched. This function never rejects a request; it only
// decorates the response.
func handleCORSHeaders(w http.ResponseWriter, r *http.Request, cfg *CORSConfig, origin string) {
	if isOriginAllowed(origin, cfg.AllowedOrigins) {
		setCORSHeaders(w, cfg, origin)
	}
}
// setCORSHeaders writes the CORS response headers derived from cfg.
//
// Access-Control-Allow-Origin is "*" when cfg.AllowedOrigins is exactly {"*"}
// and credentials are disabled. Otherwise the request origin is echoed back,
// because the CORS protocol forbids "*" together with credentials. Whenever
// the origin is echoed, "Vary: Origin" is added so that shared caches do not
// serve a response generated for one origin to a different origin.
//
// The remaining headers are written only when meaningful: Allow-Methods,
// Allow-Headers and Expose-Headers when their lists are non-empty,
// Allow-Credentials when enabled, and Max-Age when positive.
func setCORSHeaders(w http.ResponseWriter, cfg *CORSConfig, origin string) {
	h := w.Header()

	// Access-Control-Allow-Origin
	allowAll := len(cfg.AllowedOrigins) == 1 && cfg.AllowedOrigins[0] == "*"
	echoOrigin := false
	if allowAll {
		if cfg.AllowCredentials {
			// "*" is invalid with credentials, so echo the concrete origin.
			// NOTE(review): this lets any site make credentialed requests;
			// confirm that is intended for configurations using "*".
			echoOrigin = origin != ""
		} else {
			h.Set("Access-Control-Allow-Origin", "*")
		}
	} else {
		echoOrigin = isOriginAllowed(origin, cfg.AllowedOrigins)
	}
	if echoOrigin {
		h.Set("Access-Control-Allow-Origin", origin)
		// The response now varies with the request's Origin header.
		h.Add("Vary", "Origin")
	}

	// Access-Control-Allow-Methods
	if len(cfg.AllowedMethods) > 0 {
		h.Set("Access-Control-Allow-Methods", strings.Join(cfg.AllowedMethods, ", "))
	}

	// Access-Control-Allow-Headers
	if len(cfg.AllowedHeaders) > 0 {
		h.Set("Access-Control-Allow-Headers", strings.Join(cfg.AllowedHeaders, ", "))
	}

	// Access-Control-Expose-Headers
	if len(cfg.ExposedHeaders) > 0 {
		h.Set("Access-Control-Expose-Headers", strings.Join(cfg.ExposedHeaders, ", "))
	}

	// Access-Control-Allow-Credentials
	if cfg.AllowCredentials {
		h.Set("Access-Control-Allow-Credentials", "true")
	}

	// Access-Control-Max-Age
	if cfg.MaxAge > 0 {
		h.Set("Access-Control-Max-Age", strconv.Itoa(cfg.MaxAge))
	}
}
// isOriginAllowed reports whether origin matches an entry of allowedOrigins.
// An empty origin (no Origin header) is never allowed.
//
// Supported entry forms:
//   - "*": allows every non-empty origin;
//   - an exact origin such as "https://app.example.com": allows exactly that
//     string;
//   - "*.example.com": allows any origin whose host is example.com or ends in
//     ".example.com", with any scheme and port. The host is compared
//     case-insensitively and on whole DNS labels, so "https://evilexample.com"
//     does NOT match. A bare "*." pattern matches nothing.
func isOriginAllowed(origin string, allowedOrigins []string) bool {
	if origin == "" {
		return false
	}

	host, hostParsed := "", false
	for _, allowed := range allowedOrigins {
		if allowed == "*" || allowed == origin {
			return true
		}
		if !strings.HasPrefix(allowed, "*.") {
			continue
		}
		domain := strings.ToLower(strings.TrimPrefix(allowed, "*."))
		if domain == "" {
			continue
		}
		// Parse lazily: only wildcard patterns need the origin's host.
		if !hostParsed {
			host, hostParsed = originHost(origin), true
		}
		// A plain suffix test on the raw origin would let "evilexample.com"
		// pass for "*.example.com"; require a label boundary or the apex.
		if host != "" && (host == domain || strings.HasSuffix(host, "."+domain)) {
			return true
		}
	}
	return false
}

// originHost returns the lower-cased host name, without port, of an Origin
// header value such as "https://app.example.com:8443". It returns "" when the
// value does not parse as scheme://host (for example the opaque origin "null").
func originHost(origin string) string {
	u, err := url.Parse(origin)
	if err != nil || u.Host == "" {
		return ""
	}
	return strings.ToLower(u.Hostname())
}
// isMethodAllowed reports whether method, with surrounding whitespace
// ignored, is one of allowedMethods. Both sides are upper-cased before
// comparison, so the check is case-insensitive.
func isMethodAllowed(method string, allowedMethods []string) bool {
	normalized := strings.ToUpper(strings.TrimSpace(method))
	found := false
	for i := 0; i < len(allowedMethods) && !found; i++ {
		found = strings.ToUpper(allowedMethods[i]) == normalized
	}
	return found
}
// isHeaderAllowed reports whether header, with surrounding whitespace
// ignored, names one of allowedHeaders. Both sides are lower-cased before
// comparison, so header names match case-insensitively.
func isHeaderAllowed(header string, allowedHeaders []string) bool {
	wanted := strings.ToLower(strings.TrimSpace(header))
	for i := range allowedHeaders {
		if wanted != strings.ToLower(allowedHeaders[i]) {
			continue
		}
		return true
	}
	return false
}