90 lines
2.8 KiB
Go
90 lines
2.8 KiB
Go
package main
|
||
|
||
import (
	"log"
	"net"
	"net/http"
	"time"

	commonhttp "git.toowon.com/jimmy/go-common/http"
	"git.toowon.com/jimmy/go-common/middleware"
)
|
||
|
||
// 示例:限流中间件的使用
|
||
// 展示不同的限流策略
|
||
func main() {
|
||
// 策略1:按IP限流(10请求/分钟)
|
||
ipLimitChain := middleware.NewChain(
|
||
middleware.RateLimitByIP(10, time.Minute),
|
||
)
|
||
|
||
// 策略2:按用户ID限流(100请求/分钟)
|
||
limiter := middleware.NewTokenBucketLimiter(100, time.Minute)
|
||
userLimitConfig := &middleware.RateLimitConfig{
|
||
Limiter: limiter,
|
||
KeyFunc: func(r *http.Request) string {
|
||
// 从请求头获取用户ID
|
||
userID := r.Header.Get("X-User-ID")
|
||
if userID != "" {
|
||
return "user:" + userID
|
||
}
|
||
// 没有用户ID则使用IP
|
||
return "ip:" + r.RemoteAddr
|
||
},
|
||
OnRateLimitExceeded: func(w http.ResponseWriter, r *http.Request, key string) {
|
||
log.Printf("Rate limit exceeded for key: %s, path: %s", key, r.URL.Path)
|
||
},
|
||
}
|
||
userLimitChain := middleware.NewChain(
|
||
middleware.RateLimit(userLimitConfig),
|
||
)
|
||
|
||
// 路由1:按IP限流的API(严格限制)
|
||
http.Handle("/api/public", ipLimitChain.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
h := commonhttp.NewHandler(w, r)
|
||
h.Success(map[string]interface{}{
|
||
"message": "Public API - IP rate limited (10/min)",
|
||
"tip": "Try refreshing quickly to see rate limiting",
|
||
})
|
||
}))
|
||
|
||
// 路由2:按用户ID限流的API(宽松限制)
|
||
http.Handle("/api/private", userLimitChain.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
h := commonhttp.NewHandler(w, r)
|
||
userID := r.Header.Get("X-User-ID")
|
||
if userID == "" {
|
||
h.Error(401, "Missing X-User-ID header")
|
||
return
|
||
}
|
||
h.Success(map[string]interface{}{
|
||
"message": "Private API - User rate limited (100/min)",
|
||
"user_id": userID,
|
||
})
|
||
}))
|
||
|
||
// 路由3:无限流的API(用于测试对比)
|
||
http.HandleFunc("/api/unlimited", func(w http.ResponseWriter, r *http.Request) {
|
||
h := commonhttp.NewHandler(w, r)
|
||
h.Success(map[string]interface{}{
|
||
"message": "Unlimited API - No rate limiting",
|
||
})
|
||
})
|
||
|
||
// 启动服务器
|
||
addr := ":8080"
|
||
log.Printf("Server starting on %s", addr)
|
||
log.Println("Test endpoints:")
|
||
log.Printf(" - IP limited (10/min): http://localhost%s/api/public", addr)
|
||
log.Printf(" - User limited (100/min): http://localhost%s/api/private (add X-User-ID header)", addr)
|
||
log.Printf(" - No limit: http://localhost%s/api/unlimited", addr)
|
||
log.Println("\nTest with curl:")
|
||
log.Printf(" curl http://localhost%s/api/public", addr)
|
||
log.Printf(" curl -H 'X-User-ID: user123' http://localhost%s/api/private", addr)
|
||
log.Println("\nResponse headers:")
|
||
log.Println(" X-RateLimit-Limit: Total allowed requests")
|
||
log.Println(" X-RateLimit-Remaining: Remaining requests")
|
||
log.Println(" X-RateLimit-Reset: Reset timestamp")
|
||
|
||
log.Fatal(http.ListenAndServe(addr, nil))
|
||
}
|
||
|