package middleware import ( "net/http" "sync" "time" ) // RateLimiter 限流器接口 type RateLimiter interface { // Allow 检查是否允许请求 // key: 限流键(如IP地址、用户ID等) // 返回: 是否允许, 剩余配额, 重置时间 Allow(key string) (allowed bool, remaining int, resetTime time.Time) } // tokenBucketLimiter 令牌桶限流器 type tokenBucketLimiter struct { rate int // 每个窗口期允许的请求数 windowSize time.Duration // 窗口大小 buckets map[string]*bucket mu sync.RWMutex cleanupTicker *time.Ticker stopCleanup chan struct{} } // bucket 令牌桶 type bucket struct { tokens int // 当前令牌数 lastRefill time.Time // 上次填充时间 mu sync.Mutex } // NewTokenBucketLimiter 创建令牌桶限流器 func NewTokenBucketLimiter(rate int, windowSize time.Duration) RateLimiter { limiter := &tokenBucketLimiter{ rate: rate, windowSize: windowSize, buckets: make(map[string]*bucket), stopCleanup: make(chan struct{}), } // 启动清理goroutine,定期清理过期的bucket limiter.cleanupTicker = time.NewTicker(windowSize * 2) go limiter.cleanup() return limiter } // cleanup 定期清理过期的bucket func (l *tokenBucketLimiter) cleanup() { for { select { case <-l.cleanupTicker.C: l.mu.Lock() now := time.Now() for key, bkt := range l.buckets { bkt.mu.Lock() // 如果bucket超过2个窗口期没有使用,删除它 if now.Sub(bkt.lastRefill) > l.windowSize*2 { delete(l.buckets, key) } bkt.mu.Unlock() } l.mu.Unlock() case <-l.stopCleanup: l.cleanupTicker.Stop() return } } } // Allow 检查是否允许请求 func (l *tokenBucketLimiter) Allow(key string) (bool, int, time.Time) { now := time.Now() // 获取或创建bucket l.mu.Lock() bkt, exists := l.buckets[key] if !exists { bkt = &bucket{ tokens: l.rate, lastRefill: now, } l.buckets[key] = bkt } l.mu.Unlock() // 尝试消费令牌 bkt.mu.Lock() defer bkt.mu.Unlock() // 计算需要填充的令牌数 elapsed := now.Sub(bkt.lastRefill) if elapsed >= l.windowSize { // 窗口期已过,重新填充 bkt.tokens = l.rate bkt.lastRefill = now } // 检查是否有可用令牌 if bkt.tokens > 0 { bkt.tokens-- resetTime := bkt.lastRefill.Add(l.windowSize) return true, bkt.tokens, resetTime } // 没有可用令牌 resetTime := bkt.lastRefill.Add(l.windowSize) return false, 0, resetTime } // RateLimitConfig 限流中间件配置 type RateLimitConfig struct { // Limiter 限流器(必需) // 如果为nil,会使用默认的令牌桶限流器(100请求/分钟) Limiter RateLimiter // KeyFunc 生成限流键的函数(可选) // 默认使用客户端IP作为键 // 可以自定义为用户ID、API Key等 KeyFunc func(r *http.Request) string // OnRateLimitExceeded 当限流被触发时的回调(可选) // 可以用于记录日志、发送告警等 OnRateLimitExceeded func(w http.ResponseWriter, r *http.Request, key string) } // RateLimit 限流中间件 // 实现基于令牌桶算法的请求限流 // // 使用方式1:使用默认配置(100请求/分钟,按IP限流) // // chain := middleware.NewChain( // middleware.RateLimit(nil), // ) // // 使用方式2:自定义限流规则 // // limiter := middleware.NewTokenBucketLimiter(10, time.Minute) // 10请求/分钟 // chain := middleware.NewChain( // middleware.RateLimit(&middleware.RateLimitConfig{ // Limiter: limiter, // }), // ) // // 使用方式3:按用户ID限流 // // chain := middleware.NewChain( // middleware.RateLimit(&middleware.RateLimitConfig{ // Limiter: limiter, // KeyFunc: func(r *http.Request) string { // // 从请求头或token中获取用户ID // return r.Header.Get("X-User-ID") // }, // }), // ) func RateLimit(config *RateLimitConfig) func(http.Handler) http.Handler { // 如果没有配置,使用默认配置 if config == nil { config = &RateLimitConfig{} } // 如果没有提供限流器,创建默认的(100请求/分钟) if config.Limiter == nil { config.Limiter = NewTokenBucketLimiter(100, time.Minute) } // 如果没有提供KeyFunc,使用默认的(客户端IP) if config.KeyFunc == nil { config.KeyFunc = func(r *http.Request) string { return getClientIP(r) } } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 生成限流键 key := config.KeyFunc(r) if key == "" { // 如果无法生成键,允许请求通过 next.ServeHTTP(w, r) return } // 检查是否允许请求 allowed, remaining, resetTime := config.Limiter.Allow(key) // 设置限流相关的响应头 w.Header().Set("X-RateLimit-Limit", formatInt(config.Limiter.(*tokenBucketLimiter).rate)) w.Header().Set("X-RateLimit-Remaining", formatInt(remaining)) w.Header().Set("X-RateLimit-Reset", formatInt64(resetTime.Unix())) if !allowed { // 触发限流回调 if config.OnRateLimitExceeded != nil { config.OnRateLimitExceeded(w, r, key) } // 返回429错误 w.Header().Set("Retry-After", formatInt64(int64(time.Until(resetTime).Seconds()))) http.Error(w, "Too Many Requests", http.StatusTooManyRequests) return } // 允许请求通过 next.ServeHTTP(w, r) }) } } // RateLimitWithRate 使用指定速率创建限流中间件(便捷函数) // rate: 每个窗口期允许的请求数 // windowSize: 窗口大小 func RateLimitWithRate(rate int, windowSize time.Duration) func(http.Handler) http.Handler { return RateLimit(&RateLimitConfig{ Limiter: NewTokenBucketLimiter(rate, windowSize), }) } // RateLimitByIP 按IP限流(便捷函数) func RateLimitByIP(rate int, windowSize time.Duration) func(http.Handler) http.Handler { return RateLimit(&RateLimitConfig{ Limiter: NewTokenBucketLimiter(rate, windowSize), KeyFunc: func(r *http.Request) string { return getClientIP(r) }, }) } // formatInt 格式化int为字符串 func formatInt(n int) string { if n == 0 { return "0" } // 简单的int转字符串 var buf [20]byte i := len(buf) - 1 negative := n < 0 if negative { n = -n } for n > 0 { buf[i] = byte('0' + n%10) n /= 10 i-- } if negative { buf[i] = '-' i-- } return string(buf[i+1:]) } // formatInt64 格式化int64为字符串 func formatInt64(n int64) string { if n == 0 { return "0" } // 简单的int64转字符串 var buf [20]byte i := len(buf) - 1 negative := n < 0 if negative { n = -n } for n > 0 { buf[i] = byte('0' + n%10) n /= 10 i-- } if negative { buf[i] = '-' i-- } return string(buf[i+1:]) }