调整工具类的方法,优化方法调用及增加迁移工具及其用法

This commit is contained in:
2025-12-04 22:30:48 +08:00
parent de8fc13f18
commit 0650feb0d2
28 changed files with 3753 additions and 162 deletions

286
middleware/ratelimit.go Normal file
View File

@@ -0,0 +1,286 @@
package middleware
import (
"net/http"
"sync"
"time"
)
// RateLimiter 限流器接口
type RateLimiter interface {
// Allow 检查是否允许请求
// key: 限流键如IP地址、用户ID等
// 返回: 是否允许, 剩余配额, 重置时间
Allow(key string) (allowed bool, remaining int, resetTime time.Time)
}
// tokenBucketLimiter 令牌桶限流器
type tokenBucketLimiter struct {
rate int // 每个窗口期允许的请求数
windowSize time.Duration // 窗口大小
buckets map[string]*bucket
mu sync.RWMutex
cleanupTicker *time.Ticker
stopCleanup chan struct{}
}
// bucket 令牌桶
type bucket struct {
tokens int // 当前令牌数
lastRefill time.Time // 上次填充时间
mu sync.Mutex
}
// NewTokenBucketLimiter 创建令牌桶限流器
func NewTokenBucketLimiter(rate int, windowSize time.Duration) RateLimiter {
limiter := &tokenBucketLimiter{
rate: rate,
windowSize: windowSize,
buckets: make(map[string]*bucket),
stopCleanup: make(chan struct{}),
}
// 启动清理goroutine定期清理过期的bucket
limiter.cleanupTicker = time.NewTicker(windowSize * 2)
go limiter.cleanup()
return limiter
}
// cleanup 定期清理过期的bucket
func (l *tokenBucketLimiter) cleanup() {
for {
select {
case <-l.cleanupTicker.C:
l.mu.Lock()
now := time.Now()
for key, bkt := range l.buckets {
bkt.mu.Lock()
// 如果bucket超过2个窗口期没有使用删除它
if now.Sub(bkt.lastRefill) > l.windowSize*2 {
delete(l.buckets, key)
}
bkt.mu.Unlock()
}
l.mu.Unlock()
case <-l.stopCleanup:
l.cleanupTicker.Stop()
return
}
}
}
// Allow 检查是否允许请求
func (l *tokenBucketLimiter) Allow(key string) (bool, int, time.Time) {
now := time.Now()
// 获取或创建bucket
l.mu.Lock()
bkt, exists := l.buckets[key]
if !exists {
bkt = &bucket{
tokens: l.rate,
lastRefill: now,
}
l.buckets[key] = bkt
}
l.mu.Unlock()
// 尝试消费令牌
bkt.mu.Lock()
defer bkt.mu.Unlock()
// 计算需要填充的令牌数
elapsed := now.Sub(bkt.lastRefill)
if elapsed >= l.windowSize {
// 窗口期已过,重新填充
bkt.tokens = l.rate
bkt.lastRefill = now
}
// 检查是否有可用令牌
if bkt.tokens > 0 {
bkt.tokens--
resetTime := bkt.lastRefill.Add(l.windowSize)
return true, bkt.tokens, resetTime
}
// 没有可用令牌
resetTime := bkt.lastRefill.Add(l.windowSize)
return false, 0, resetTime
}
// RateLimitConfig 限流中间件配置
type RateLimitConfig struct {
// Limiter 限流器(必需)
// 如果为nil会使用默认的令牌桶限流器100请求/分钟)
Limiter RateLimiter
// KeyFunc 生成限流键的函数(可选)
// 默认使用客户端IP作为键
// 可以自定义为用户ID、API Key等
KeyFunc func(r *http.Request) string
// OnRateLimitExceeded 当限流被触发时的回调(可选)
// 可以用于记录日志、发送告警等
OnRateLimitExceeded func(w http.ResponseWriter, r *http.Request, key string)
}
// RateLimit 限流中间件
// 实现基于令牌桶算法的请求限流
//
// 使用方式1使用默认配置100请求/分钟按IP限流
//
// chain := middleware.NewChain(
// middleware.RateLimit(nil),
// )
//
// 使用方式2自定义限流规则
//
// limiter := middleware.NewTokenBucketLimiter(10, time.Minute) // 10请求/分钟
// chain := middleware.NewChain(
// middleware.RateLimit(&middleware.RateLimitConfig{
// Limiter: limiter,
// }),
// )
//
// 使用方式3按用户ID限流
//
// chain := middleware.NewChain(
// middleware.RateLimit(&middleware.RateLimitConfig{
// Limiter: limiter,
// KeyFunc: func(r *http.Request) string {
// // 从请求头或token中获取用户ID
// return r.Header.Get("X-User-ID")
// },
// }),
// )
func RateLimit(config *RateLimitConfig) func(http.Handler) http.Handler {
// 如果没有配置,使用默认配置
if config == nil {
config = &RateLimitConfig{}
}
// 如果没有提供限流器创建默认的100请求/分钟)
if config.Limiter == nil {
config.Limiter = NewTokenBucketLimiter(100, time.Minute)
}
// 如果没有提供KeyFunc使用默认的客户端IP
if config.KeyFunc == nil {
config.KeyFunc = func(r *http.Request) string {
return getClientIP(r)
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 生成限流键
key := config.KeyFunc(r)
if key == "" {
// 如果无法生成键,允许请求通过
next.ServeHTTP(w, r)
return
}
// 检查是否允许请求
allowed, remaining, resetTime := config.Limiter.Allow(key)
// 设置限流相关的响应头
w.Header().Set("X-RateLimit-Limit", formatInt(config.Limiter.(*tokenBucketLimiter).rate))
w.Header().Set("X-RateLimit-Remaining", formatInt(remaining))
w.Header().Set("X-RateLimit-Reset", formatInt64(resetTime.Unix()))
if !allowed {
// 触发限流回调
if config.OnRateLimitExceeded != nil {
config.OnRateLimitExceeded(w, r, key)
}
// 返回429错误
w.Header().Set("Retry-After", formatInt64(int64(time.Until(resetTime).Seconds())))
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
// 允许请求通过
next.ServeHTTP(w, r)
})
}
}
// RateLimitWithRate 使用指定速率创建限流中间件(便捷函数)
// rate: 每个窗口期允许的请求数
// windowSize: 窗口大小
func RateLimitWithRate(rate int, windowSize time.Duration) func(http.Handler) http.Handler {
return RateLimit(&RateLimitConfig{
Limiter: NewTokenBucketLimiter(rate, windowSize),
})
}
// RateLimitByIP 按IP限流便捷函数
func RateLimitByIP(rate int, windowSize time.Duration) func(http.Handler) http.Handler {
return RateLimit(&RateLimitConfig{
Limiter: NewTokenBucketLimiter(rate, windowSize),
KeyFunc: func(r *http.Request) string {
return getClientIP(r)
},
})
}
// formatInt 格式化int为字符串
func formatInt(n int) string {
if n == 0 {
return "0"
}
// 简单的int转字符串
var buf [20]byte
i := len(buf) - 1
negative := n < 0
if negative {
n = -n
}
for n > 0 {
buf[i] = byte('0' + n%10)
n /= 10
i--
}
if negative {
buf[i] = '-'
i--
}
return string(buf[i+1:])
}
// formatInt64 格式化int64为字符串
func formatInt64(n int64) string {
if n == 0 {
return "0"
}
// 简单的int64转字符串
var buf [20]byte
i := len(buf) - 1
negative := n < 0
if negative {
n = -n
}
for n > 0 {
buf[i] = byte('0' + n%10)
n /= 10
i--
}
if negative {
buf[i] = '-'
i--
}
return string(buf[i+1:])
}