150 lines
3.9 KiB
Go
150 lines
3.9 KiB
Go
package middleware
|
||
|
||
import (
|
||
"fmt"
|
||
"net/http"
|
||
"runtime/debug"
|
||
|
||
"git.toowon.com/jimmy/go-common/logger"
|
||
)
|
||
|
||
// RecoveryConfig Recovery中间件配置
|
||
type RecoveryConfig struct {
|
||
// Logger 日志记录器(可选,如果为nil则使用默认logger)
|
||
Logger *logger.Logger
|
||
|
||
// EnableStackTrace 是否在日志中包含堆栈跟踪
|
||
EnableStackTrace bool
|
||
|
||
// CustomHandler 自定义错误处理函数(可选)
|
||
// 如果设置了,会在记录日志后调用此函数
|
||
// 可以用于自定义错误响应格式
|
||
CustomHandler func(w http.ResponseWriter, r *http.Request, err interface{})
|
||
}
|
||
|
||
// Recovery Panic恢复中间件
|
||
// 捕获HTTP处理过程中的panic,记录错误日志并返回500错误
|
||
// 防止panic导致整个服务崩溃
|
||
//
|
||
// 使用方式1:使用默认配置
|
||
//
|
||
// chain := middleware.NewChain(
|
||
// middleware.Recovery(nil),
|
||
// )
|
||
//
|
||
// 使用方式2:使用自定义配置
|
||
//
|
||
// myLogger, _ := logger.NewLogger(loggerConfig)
|
||
// chain := middleware.NewChain(
|
||
// middleware.Recovery(&middleware.RecoveryConfig{
|
||
// Logger: myLogger,
|
||
// EnableStackTrace: true,
|
||
// }),
|
||
// )
|
||
//
|
||
// 使用方式3:自定义错误响应
|
||
//
|
||
// chain := middleware.NewChain(
|
||
// middleware.Recovery(&middleware.RecoveryConfig{
|
||
// Logger: myLogger,
|
||
// CustomHandler: func(w http.ResponseWriter, r *http.Request, err interface{}) {
|
||
// // 自定义JSON响应
|
||
// w.Header().Set("Content-Type", "application/json")
|
||
// w.WriteHeader(http.StatusInternalServerError)
|
||
// w.Write([]byte(`{"code":500,"message":"Internal Server Error"}`))
|
||
// },
|
||
// }),
|
||
// )
|
||
func Recovery(config *RecoveryConfig) func(http.Handler) http.Handler {
|
||
// 如果没有配置,使用默认配置
|
||
if config == nil {
|
||
config = &RecoveryConfig{
|
||
EnableStackTrace: true, // 默认启用堆栈跟踪
|
||
}
|
||
}
|
||
|
||
// 如果没有提供logger,创建一个默认的
|
||
if config.Logger == nil {
|
||
defaultLogger, err := logger.NewLogger(nil)
|
||
if err != nil {
|
||
config.Logger = nil
|
||
} else {
|
||
config.Logger = defaultLogger
|
||
}
|
||
}
|
||
|
||
return func(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
defer func() {
|
||
if err := recover(); err != nil {
|
||
// 记录panic信息
|
||
logPanic(config.Logger, r, err, config.EnableStackTrace)
|
||
|
||
// 如果提供了自定义处理函数,调用它
|
||
if config.CustomHandler != nil {
|
||
config.CustomHandler(w, r, err)
|
||
return
|
||
}
|
||
|
||
// 默认错误响应
|
||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||
}
|
||
}()
|
||
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
}
|
||
|
||
// logPanic 记录panic日志
|
||
func logPanic(log *logger.Logger, r *http.Request, err interface{}, enableStackTrace bool) {
|
||
// 获取堆栈跟踪
|
||
var stack string
|
||
if enableStackTrace {
|
||
stack = string(debug.Stack())
|
||
}
|
||
|
||
// 构建日志字段
|
||
fields := map[string]interface{}{
|
||
"method": r.Method,
|
||
"path": r.URL.Path,
|
||
"query": r.URL.RawQuery,
|
||
"ip": getClientIP(r),
|
||
"error": fmt.Sprintf("%v", err),
|
||
}
|
||
|
||
// 构建日志消息
|
||
message := "Panic recovered"
|
||
if enableStackTrace && stack != "" {
|
||
message += "\n" + stack
|
||
}
|
||
|
||
// 记录错误日志
|
||
if log != nil {
|
||
log.Errorf(fields, message)
|
||
} else {
|
||
// 降级处理:输出到标准错误
|
||
fmt.Printf("[ERROR] %s\n", message)
|
||
for k, v := range fields {
|
||
fmt.Printf(" %s: %v\n", k, v)
|
||
}
|
||
}
|
||
}
|
||
|
||
// RecoveryWithLogger 使用指定logger的Recovery中间件(便捷函数)
|
||
func RecoveryWithLogger(log *logger.Logger) func(http.Handler) http.Handler {
|
||
return Recovery(&RecoveryConfig{
|
||
Logger: log,
|
||
EnableStackTrace: true,
|
||
})
|
||
}
|
||
|
||
// RecoveryWithCustomHandler 使用自定义错误处理的Recovery中间件(便捷函数)
|
||
func RecoveryWithCustomHandler(customHandler func(w http.ResponseWriter, r *http.Request, err interface{})) func(http.Handler) http.Handler {
|
||
return Recovery(&RecoveryConfig{
|
||
EnableStackTrace: true,
|
||
CustomHandler: customHandler,
|
||
})
|
||
}
|
||
|