package middleware import ( "fmt" "net/http" "runtime/debug" "git.toowon.com/jimmy/go-common/logger" ) // RecoveryConfig Recovery中间件配置 type RecoveryConfig struct { // Logger 日志记录器(可选,如果为nil则使用默认logger) Logger *logger.Logger // EnableStackTrace 是否在日志中包含堆栈跟踪 EnableStackTrace bool // CustomHandler 自定义错误处理函数(可选) // 如果设置了,会在记录日志后调用此函数 // 可以用于自定义错误响应格式 CustomHandler func(w http.ResponseWriter, r *http.Request, err interface{}) } // Recovery Panic恢复中间件 // 捕获HTTP处理过程中的panic,记录错误日志并返回500错误 // 防止panic导致整个服务崩溃 // // 使用方式1:使用默认配置 // // chain := middleware.NewChain( // middleware.Recovery(nil), // ) // // 使用方式2:使用自定义配置 // // myLogger, _ := logger.NewLogger(loggerConfig) // chain := middleware.NewChain( // middleware.Recovery(&middleware.RecoveryConfig{ // Logger: myLogger, // EnableStackTrace: true, // }), // ) // // 使用方式3:自定义错误响应 // // chain := middleware.NewChain( // middleware.Recovery(&middleware.RecoveryConfig{ // Logger: myLogger, // CustomHandler: func(w http.ResponseWriter, r *http.Request, err interface{}) { // // 自定义JSON响应 // w.Header().Set("Content-Type", "application/json") // w.WriteHeader(http.StatusInternalServerError) // w.Write([]byte(`{"code":500,"message":"Internal Server Error"}`)) // }, // }), // ) func Recovery(config *RecoveryConfig) func(http.Handler) http.Handler { // 如果没有配置,使用默认配置 if config == nil { config = &RecoveryConfig{ EnableStackTrace: true, // 默认启用堆栈跟踪 } } // 如果没有提供logger,创建一个默认的 if config.Logger == nil { defaultLogger, err := logger.NewLogger(nil) if err != nil { config.Logger = nil } else { config.Logger = defaultLogger } } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { // 记录panic信息 logPanic(config.Logger, r, err, config.EnableStackTrace) // 如果提供了自定义处理函数,调用它 if config.CustomHandler != nil { config.CustomHandler(w, r, err) return } // 默认错误响应 http.Error(w, "Internal Server Error", http.StatusInternalServerError) } }() next.ServeHTTP(w, r) }) } } // logPanic 记录panic日志 func logPanic(log *logger.Logger, r *http.Request, err interface{}, enableStackTrace bool) { // 获取堆栈跟踪 var stack string if enableStackTrace { stack = string(debug.Stack()) } // 构建日志字段 fields := map[string]interface{}{ "method": r.Method, "path": r.URL.Path, "query": r.URL.RawQuery, "ip": getClientIP(r), "error": fmt.Sprintf("%v", err), } // 构建日志消息 message := "Panic recovered" if enableStackTrace && stack != "" { message += "\n" + stack } // 记录错误日志 if log != nil { log.Errorf(fields, message) } else { // 降级处理:输出到标准错误 fmt.Printf("[ERROR] %s\n", message) for k, v := range fields { fmt.Printf(" %s: %v\n", k, v) } } } // RecoveryWithLogger 使用指定logger的Recovery中间件(便捷函数) func RecoveryWithLogger(log *logger.Logger) func(http.Handler) http.Handler { return Recovery(&RecoveryConfig{ Logger: log, EnableStackTrace: true, }) } // RecoveryWithCustomHandler 使用自定义错误处理的Recovery中间件(便捷函数) func RecoveryWithCustomHandler(customHandler func(w http.ResponseWriter, r *http.Request, err interface{})) func(http.Handler) http.Handler { return Recovery(&RecoveryConfig{ EnableStackTrace: true, CustomHandler: customHandler, }) }