Files
go-common/middleware/recovery.go
2025-12-05 00:07:15 +08:00

150 lines
3.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"fmt"
"net/http"
"runtime/debug"
"git.toowon.com/jimmy/go-common/logger"
)
// RecoveryConfig Recovery中间件配置
type RecoveryConfig struct {
// Logger 日志记录器可选如果为nil则使用默认logger
Logger *logger.Logger
// EnableStackTrace 是否在日志中包含堆栈跟踪
EnableStackTrace bool
// CustomHandler 自定义错误处理函数(可选)
// 如果设置了,会在记录日志后调用此函数
// 可以用于自定义错误响应格式
CustomHandler func(w http.ResponseWriter, r *http.Request, err interface{})
}
// Recovery Panic恢复中间件
// 捕获HTTP处理过程中的panic记录错误日志并返回500错误
// 防止panic导致整个服务崩溃
//
// 使用方式1使用默认配置
//
// chain := middleware.NewChain(
// middleware.Recovery(nil),
// )
//
// 使用方式2使用自定义配置
//
// myLogger, _ := logger.NewLogger(loggerConfig)
// chain := middleware.NewChain(
// middleware.Recovery(&middleware.RecoveryConfig{
// Logger: myLogger,
// EnableStackTrace: true,
// }),
// )
//
// 使用方式3自定义错误响应
//
// chain := middleware.NewChain(
// middleware.Recovery(&middleware.RecoveryConfig{
// Logger: myLogger,
// CustomHandler: func(w http.ResponseWriter, r *http.Request, err interface{}) {
// // 自定义JSON响应
// w.Header().Set("Content-Type", "application/json")
// w.WriteHeader(http.StatusInternalServerError)
// w.Write([]byte(`{"code":500,"message":"Internal Server Error"}`))
// },
// }),
// )
func Recovery(config *RecoveryConfig) func(http.Handler) http.Handler {
// 如果没有配置,使用默认配置
if config == nil {
config = &RecoveryConfig{
EnableStackTrace: true, // 默认启用堆栈跟踪
}
}
// 如果没有提供logger创建一个默认的
if config.Logger == nil {
defaultLogger, err := logger.NewLogger(nil)
if err != nil {
config.Logger = nil
} else {
config.Logger = defaultLogger
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
// 记录panic信息
logPanic(config.Logger, r, err, config.EnableStackTrace)
// 如果提供了自定义处理函数,调用它
if config.CustomHandler != nil {
config.CustomHandler(w, r, err)
return
}
// 默认错误响应
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}
}
// logPanic 记录panic日志
func logPanic(log *logger.Logger, r *http.Request, err interface{}, enableStackTrace bool) {
// 获取堆栈跟踪
var stack string
if enableStackTrace {
stack = string(debug.Stack())
}
// 构建日志字段
fields := map[string]interface{}{
"method": r.Method,
"path": r.URL.Path,
"query": r.URL.RawQuery,
"ip": GetClientIP(r),
"error": fmt.Sprintf("%v", err),
}
// 构建日志消息
message := "Panic recovered"
if enableStackTrace && stack != "" {
message += "\n" + stack
}
// 记录错误日志
if log != nil {
log.Errorf(fields, message)
} else {
// 降级处理:输出到标准错误
fmt.Printf("[ERROR] %s\n", message)
for k, v := range fields {
fmt.Printf(" %s: %v\n", k, v)
}
}
}
// RecoveryWithLogger 使用指定logger的Recovery中间件便捷函数
func RecoveryWithLogger(log *logger.Logger) func(http.Handler) http.Handler {
return Recovery(&RecoveryConfig{
Logger: log,
EnableStackTrace: true,
})
}
// RecoveryWithCustomHandler 使用自定义错误处理的Recovery中间件便捷函数
func RecoveryWithCustomHandler(customHandler func(w http.ResponseWriter, r *http.Request, err interface{})) func(http.Handler) http.Handler {
return Recovery(&RecoveryConfig{
EnableStackTrace: true,
CustomHandler: customHandler,
})
}