package middleware import ( "net/http" "strconv" "time" "git.toowon.com/jimmy/go-common/logger" ) // responseWriter 包装 http.ResponseWriter 以捕获状态码和响应大小 type responseWriter struct { http.ResponseWriter statusCode int size int } func (rw *responseWriter) WriteHeader(statusCode int) { rw.statusCode = statusCode rw.ResponseWriter.WriteHeader(statusCode) } func (rw *responseWriter) Write(b []byte) (int, error) { size, err := rw.ResponseWriter.Write(b) rw.size += size return size, err } // LoggingConfig 日志中间件配置 type LoggingConfig struct { // Logger 日志记录器(可选,如果为nil则使用默认logger) Logger *logger.Logger // SkipPaths 跳过记录的路径列表(如健康检查接口) SkipPaths []string // LogRequestBody 是否记录请求体(谨慎使用,可能影响性能) LogRequestBody bool // LogResponseBody 是否记录响应体(谨慎使用,可能影响性能和内存) LogResponseBody bool } // Logging HTTP请求日志中间件 // 记录每个HTTP请求的详细信息,包括: // - 请求方法、路径、IP、User-Agent // - 响应状态码、响应大小 // - 请求处理时间 // // 使用方式1:使用默认logger // // chain := middleware.NewChain( // middleware.Logging(nil), // ) // // 使用方式2:使用自定义logger // // myLogger, _ := logger.NewLogger(loggerConfig) // chain := middleware.NewChain( // middleware.Logging(&middleware.LoggingConfig{ // Logger: myLogger, // SkipPaths: []string{"/health", "/metrics"}, // }), // ) func Logging(config *LoggingConfig) func(http.Handler) http.Handler { // 如果没有配置,使用默认配置 if config == nil { config = &LoggingConfig{} } // 如果没有提供logger,创建一个默认的 if config.Logger == nil { // 使用默认配置创建logger(输出到stdout,info级别) defaultLogger, err := logger.NewLogger(nil) if err != nil { // 如果创建失败,使用nil,后面会降级处理 config.Logger = nil } else { config.Logger = defaultLogger } } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 检查是否跳过此路径 if shouldSkipPath(r.URL.Path, config.SkipPaths) { next.ServeHTTP(w, r) return } // 记录开始时间 startTime := time.Now() // 包装 ResponseWriter 以捕获状态码和响应大小 rw := &responseWriter{ ResponseWriter: w, statusCode: http.StatusOK, // 默认200 size: 0, } // 处理请求 next.ServeHTTP(rw, r) // 计算处理时间 duration := time.Since(startTime) // 记录日志 logHTTPRequest(config.Logger, r, rw, duration) }) } } // shouldSkipPath 检查是否应该跳过该路径 func shouldSkipPath(path string, skipPaths []string) bool { for _, skipPath := range skipPaths { if path == skipPath { return true } } return false } // logHTTPRequest 记录HTTP请求日志 func logHTTPRequest(log *logger.Logger, r *http.Request, rw *responseWriter, duration time.Duration) { // 获取客户端IP clientIP := GetClientIP(r) // 构建日志字段 fields := map[string]interface{}{ "method": r.Method, "path": r.URL.Path, "query": r.URL.RawQuery, "status": rw.statusCode, "size": rw.size, "duration": duration.Milliseconds(), // 毫秒 "ip": clientIP, "user_agent": r.UserAgent(), "referer": r.Referer(), } // 构建日志消息 message := "HTTP Request" // 根据状态码选择日志级别 if log != nil { // 使用提供的logger if rw.statusCode >= 500 { log.Errorf(fields, message) } else if rw.statusCode >= 400 { log.Warnf(fields, message) } else { log.Infof(fields, message) } } else { // 降级处理:使用标准输出 // 注意:这是同步的,不会有性能问题 if rw.statusCode >= 500 { logToStdout("ERROR", fields, message) } else if rw.statusCode >= 400 { logToStdout("WARN", fields, message) } else { logToStdout("INFO", fields, message) } } } // logToStdout 降级处理:输出到标准输出(当logger不可用时) func logToStdout(level string, fields map[string]interface{}, message string) { // 简单的标准输出日志 var fieldStr string for k, v := range fields { fieldStr += " " + k + "=" + formatValue(v) } println("[" + level + "] " + message + fieldStr) } // formatValue 格式化值(用于日志输出) func formatValue(v interface{}) string { switch val := v.(type) { case string: return val case int: return strconv.Itoa(val) case int64: return strconv.FormatInt(val, 10) default: return "" } } // GetClientIP 获取客户端真实IP // 优先级:X-Forwarded-For > X-Real-IP > RemoteAddr func GetClientIP(r *http.Request) string { // 尝试从 X-Forwarded-For 获取 xff := r.Header.Get("X-Forwarded-For") if xff != "" { // X-Forwarded-For 可能包含多个IP,取第一个 for idx := 0; idx < len(xff); idx++ { if xff[idx] == ',' { return xff[:idx] } } return xff } // 尝试从 X-Real-IP 获取 xri := r.Header.Get("X-Real-IP") if xri != "" { return xri } // 使用 RemoteAddr remoteAddr := r.RemoteAddr // 移除端口号 for i := len(remoteAddr) - 1; i >= 0; i-- { if remoteAddr[i] == ':' { return remoteAddr[:i] } } return remoteAddr }